06. Transfer Learning with TensorFlow Part 3: Scaling up (🍔👁 Food Vision mini)¶

In the previous two notebooks (transfer learning part 1: feature extraction and part 2: fine-tuning) we've seen the power of transfer learning.

Now we know our smaller modelling experiments are working, it's time to step things up a notch with more data.

This is a common practice in machine learning and deep learning: get a model working on a small amount of data before scaling it up to a larger amount of data.

🔑 Note: You haven't forgotten the machine learning practitioner's motto, have you? "Experiment, experiment, experiment."

It's time to get closer to our Food Vision project coming to life. In this notebook we're going to scale up from using 10 classes of the Food101 data to using all of the classes in the Food101 dataset.

Our goal is to beat the original Food101 paper's results with only 10% of the training data.

Machine learning practitioners are serial experimenters. Start small, get a model working, see if your experiments work then gradually scale them up to where you want to go (we're going to be looking at scaling up throughout this notebook).

What we're going to cover¶

We're going to go through the following with TensorFlow:

  • Downloading and preparing 10% of the Food101 data (10% of training data)
  • Training a feature extraction transfer learning model on 10% of the Food101 training data
  • Fine-tuning our feature extraction model
  • Saving and loading our trained model
  • Evaluating the performance of our Food Vision model trained on 10% of the training data
    • Finding our model's most wrong predictions
  • Making predictions with our Food Vision model on custom images of food

How you can use this notebook¶

You can read through the descriptions and the code (it should all run, except for the cells which error on purpose), but there's a better option.

Write all of the code yourself.

Yes. I'm serious. Create a new notebook, and rewrite each line by yourself. Investigate it, see if you can break it, why does it break?

You don't have to write the text descriptions but writing the code yourself is a great way to get hands-on experience.

Don't worry if you make mistakes, we all do. The way to get better and make fewer mistakes is to write more code.

📖 Resource: See the full set of course materials on GitHub: https://github.com/mrdbourke/tensorflow-deep-learning

In [1]:
# Are we using a GPU?
!nvidia-smi
Sat Jul 30 21:11:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
|  0%   56C    P0    66W / 170W |    450MiB /  4096MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |
|  0%   47C    P8    14W / 170W |      6MiB /  4096MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1416      G   /usr/lib/xorg/Xorg                268MiB |
|    0   N/A  N/A      1625      G   ...ome-remote-desktop-daemon        2MiB |
|    0   N/A  N/A      1683      G   /usr/bin/gnome-shell               56MiB |
|    0   N/A  N/A      3107      G   ...754728995500944683,131072      116MiB |
|    1   N/A  N/A      1416      G   /usr/lib/xorg/Xorg                  3MiB |
+-----------------------------------------------------------------------------+

Creating helper functions¶

We've created a series of helper functions throughout the previous notebooks. Instead of rewriting them (tedious), we'll import the helper_functions.py file from the GitHub repo.

In [2]:
# Get helper functions file
!wget https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
--2022-07-30 16:38:29--  https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10246 (10K) [text/plain]
Saving to: ‘helper_functions.py.1’

helper_functions.py 100%[===================>]  10.01K  --.-KB/s    in 0s      

2022-07-30 16:38:29 (98.5 MB/s) - ‘helper_functions.py.1’ saved [10246/10246]

In [3]:
# Import series of helper functions for the notebook (we've created/used these in previous notebooks)
from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, compare_historys, walk_through_dir

101 Food Classes: Working with less data¶

So far we've confirmed the transfer learning models we've been using work pretty well with the 10 Food Classes dataset. Now it's time to step it up and see how they go with the full 101 Food Classes.

In the original Food101 dataset there are 1,000 images per class (750 of each class in the training set and 250 of each class in the test set), totalling 101,000 images.

We could start modelling straight away on this large dataset, but in the spirit of continually experimenting, we're going to see how our previously working models go with 10% of the training data.

This means for each of the 101 food classes we'll be building a model on 75 training images and evaluating it on 250 test images.
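
To make the numbers above concrete, here's a small arithmetic sketch. It only uses the figures quoted in the text (101 classes, 750 training and 250 test images per class, 10% of the training images); the variable names are made up for illustration.

```python
# Food101 split sizes, using the figures quoted in the text
NUM_CLASSES = 101
TRAIN_PER_CLASS = 750
TEST_PER_CLASS = 250
SUBSET_PERCENT = 10  # we keep 10% of the training images

full_train = NUM_CLASSES * TRAIN_PER_CLASS
full_test = NUM_CLASSES * TEST_PER_CLASS
subset_per_class = TRAIN_PER_CLASS * SUBSET_PERCENT // 100
subset_train = NUM_CLASSES * subset_per_class

print(f"Full dataset: {full_train + full_test:,} images")
print(f"Full training set: {full_train:,} images")
print(f"10% training subset: {subset_train:,} images ({subset_per_class} per class)")
print(f"Test set (unchanged): {full_test:,} images")
```

Notice the test set is left at its full size, so the model is evaluated on more than three times as many images as it trains on.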

Downloading and preprocessing the data¶

Just as before we'll download a subset of the Food101 dataset which has been extracted from the original dataset (to see the preprocessing of the data check out the Food Vision preprocessing notebook).

We download the data as a zip file so we'll use our unzip_data() function to unzip it.
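
If you're curious what a helper like unzip_data() might be doing, here's a minimal sketch using Python's built-in zipfile module. The function name unzip_archive and its target_dir parameter are assumptions for illustration; check helper_functions.py for the real implementation.

```python
import zipfile

def unzip_archive(filename, target_dir="."):
    """Extract every file in the zip archive `filename` into `target_dir`.

    Hypothetical stand-in for the course's unzip_data() helper.
    """
    with zipfile.ZipFile(filename, "r") as zip_ref:
        zip_ref.extractall(target_dir)

# Example usage (assumes the zip file has already been downloaded):
# unzip_archive("101_food_classes_10_percent.zip")
```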

In [4]:
# Download data from Google Storage (already preformatted)
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip 

unzip_data("101_food_classes_10_percent.zip")

train_dir = "101_food_classes_10_percent/train/"
test_dir = "101_food_classes_10_percent/test/"
--2022-07-30 16:38:38--  https://storage.googleapis.com/ztm_tf_course/food_vision/101_food_classes_10_percent.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.125.128, 74.125.124.128, 172.217.212.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.125.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1625420029 (1.5G) [application/zip]
Saving to: ‘101_food_classes_10_percent.zip.1’

101_food_classes_10 100%[===================>]   1.51G   241MB/s    in 6.0s    

2022-07-30 16:38:44 (257 MB/s) - ‘101_food_classes_10_percent.zip.1’ saved [1625420029/1625420029]

In [5]:
# How many images/classes are there?
walk_through_dir("101_food_classes_10_percent")
There are 2 directories and 0 images in '101_food_classes_10_percent'.
There are 101 directories and 0 images in '101_food_classes_10_percent/test'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cannoli'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cheese_plate'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/takoyaki'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/frozen_yogurt'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/paella'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chocolate_mousse'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/grilled_salmon'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_onion_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/breakfast_burrito'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/scallops'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/dumplings'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hamburger'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/filet_mignon'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/bread_pudding'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/poutine'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/creme_brulee'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/spaghetti_carbonara'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_quesadilla'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ravioli'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/mussels'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ramen'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pad_thai'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/red_velvet_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/seaweed_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hot_dog'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/macaroni_and_cheese'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/strawberry_shortcake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/edamame'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/miso_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beef_carpaccio'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/baklava'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ice_cream'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/carrot_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/garlic_bread'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/fish_and_chips'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/samosa'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/prime_rib'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/lobster_bisque'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tuna_tartare'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_wings'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tiramisu'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hot_and_sour_soup'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pizza'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/club_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/bruschetta'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pork_chop'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/fried_rice'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/oysters'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/tacos'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cup_cakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/sashimi'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/deviled_eggs'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/churros'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/onion_rings'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/huevos_rancheros'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/peking_duck'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/foie_gras'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/caesar_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/baby_back_ribs'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/donuts'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/greek_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/ceviche'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/crab_cakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/panna_cotta'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/sushi'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/grilled_cheese_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/falafel'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beef_tartare'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/lasagna'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pancakes'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/spring_rolls'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/gnocchi'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_fries'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chicken_curry'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/lobster_roll_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/french_toast'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beignets'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/nachos'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/escargots'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/caprese_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/guacamole'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pulled_pork_sandwich'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/fried_calamari'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/gyoza'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/beet_salad'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/bibimbap'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/risotto'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/clam_chowder'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/omelette'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/eggs_benedict'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/pho'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/steak'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/waffles'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/hummus'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/chocolate_cake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/apple_pie'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/macarons'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/croque_madame'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/cheesecake'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/spaghetti_bolognese'.
There are 0 directories and 250 images in '101_food_classes_10_percent/test/shrimp_and_grits'.
There are 101 directories and 0 images in '101_food_classes_10_percent/train'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cannoli'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cheese_plate'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/takoyaki'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/frozen_yogurt'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/paella'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chocolate_mousse'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/grilled_salmon'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_onion_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/breakfast_burrito'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/scallops'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/dumplings'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hamburger'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/filet_mignon'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/bread_pudding'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/poutine'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/creme_brulee'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/spaghetti_carbonara'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_quesadilla'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ravioli'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/mussels'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ramen'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pad_thai'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/red_velvet_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/seaweed_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hot_dog'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/macaroni_and_cheese'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/strawberry_shortcake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/edamame'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/miso_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beef_carpaccio'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/baklava'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ice_cream'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/carrot_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/garlic_bread'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/fish_and_chips'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/samosa'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/prime_rib'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/lobster_bisque'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/tuna_tartare'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_wings'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/tiramisu'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hot_and_sour_soup'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pizza'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/club_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/bruschetta'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pork_chop'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/fried_rice'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/oysters'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/tacos'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cup_cakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/sashimi'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/deviled_eggs'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/churros'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/onion_rings'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/huevos_rancheros'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/peking_duck'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/foie_gras'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/caesar_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/baby_back_ribs'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/donuts'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/greek_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/ceviche'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/crab_cakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/panna_cotta'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/sushi'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/grilled_cheese_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/falafel'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beef_tartare'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/lasagna'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pancakes'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/spring_rolls'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/gnocchi'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_fries'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chicken_curry'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/lobster_roll_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/french_toast'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beignets'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/nachos'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/escargots'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/caprese_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/guacamole'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pulled_pork_sandwich'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/fried_calamari'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/gyoza'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/beet_salad'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/bibimbap'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/risotto'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/clam_chowder'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/omelette'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/eggs_benedict'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/pho'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/steak'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/waffles'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/hummus'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/chocolate_cake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/apple_pie'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/macarons'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/croque_madame'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/cheesecake'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/spaghetti_bolognese'.
There are 0 directories and 75 images in '101_food_classes_10_percent/train/shrimp_and_grits'.

As before our data comes in the common image classification data format of:

Example of file structure

101_food_classes_10_percent <- top level folder
└───train <- training images
│   └───pizza
│   │   │   1008104.jpg
│   │   │   1638227.jpg
│   │   │   ...      
│   └───steak
│       │   1000205.jpg
│       │   1647351.jpg
│       │   ...
│   
└───test <- testing images
│   └───pizza
│   │   │   1001116.jpg
│   │   │   1507019.jpg
│   │   │   ...      
│   └───steak
│       │   100274.jpg
│       │   1653815.jpg
│       │   ...
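
Because each class lives in its own folder, the layout above can be checked with a short directory walk. The sketch below is a simplified take on that idea, not the course's walk_through_dir() helper; the function name count_images_per_class is hypothetical.

```python
import os

def count_images_per_class(split_dir):
    """Return {class_name: number_of_files} for one split directory (e.g. train/)."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_path = os.path.join(split_dir, class_name)
        if os.path.isdir(class_path):
            # Count only regular files inside the class folder
            counts[class_name] = sum(
                os.path.isfile(os.path.join(class_path, entry))
                for entry in os.listdir(class_path)
            )
    return counts

# Example usage (assumes the archive has been unzipped in the working directory):
# train_counts = count_images_per_class("101_food_classes_10_percent/train")
# print(len(train_counts), set(train_counts.values()))
```

If the data matches the printout from walk_through_dir(), the training split should give 101 classes that all share a single count of 75.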

Let's use the image_dataset_from_directory() function to turn our images and labels into a tf.data.Dataset, a TensorFlow datatype which allows us to pass it directly to our model.

For the test dataset, we're going to set shuffle=False so we can perform repeatable evaluation and visualization on it later.
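
When labels are inferred from folder names, image_dataset_from_directory() assigns class indices in alphanumeric order of the subdirectories, and label_mode="categorical" turns each index into a one-hot vector. The plain-Python sketch below mimics that mapping on a hypothetical three-class subset (it doesn't call TensorFlow and the names are made up for illustration).

```python
def one_hot(index, depth):
    """Return a one-hot list of length `depth` with a 1.0 at `index`."""
    return [1.0 if i == index else 0.0 for i in range(depth)]

# Hypothetical subset of class folders, deliberately out of order
class_dirs = ["pizza", "apple_pie", "steak"]

# Class indices follow the sorted folder names
class_names = sorted(class_dirs)
class_to_index = {name: i for i, name in enumerate(class_names)}

for name in class_dirs:
    index = class_to_index[name]
    print(name, index, one_hot(index, len(class_names)))
```

With shuffle=False the files are also read in sorted order, so labels collected from one pass over the test data line up row by row with the model's predictions.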

In [6]:
# Setup data inputs
import tensorflow as tf
IMG_SIZE = (224, 224)
train_data_all_10_percent = tf.keras.preprocessing.image_dataset_from_directory(train_dir,
                                                                                label_mode="categorical",
                                                                                image_size=IMG_SIZE)
                                                                                
test_data = tf.keras.preprocessing.image_dataset_from_directory(test_dir,
                                                                label_mode="categorical",
                                                                image_size=IMG_SIZE,
                                                                shuffle=False) # don't shuffle test data for prediction analysis
Found 7575 files belonging to 101 classes.
Found 25250 files belonging to 101 classes.

Wonderful! It looks like our data has been imported as expected, with 75 images per class in the training set (75 images × 101 classes = 7,575 images) and 250 images per class in the test set (250 images × 101 classes = 25,250 images).
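
Since we didn't pass batch_size, image_dataset_from_directory() uses its default of 32, so both datasets are batched. Here's a quick sketch of how many batches that should produce (the variable names are made up for illustration; calling len() on each dataset should report the same numbers).

```python
import math

BATCH_SIZE = 32        # default batch_size of image_dataset_from_directory()
NUM_TRAIN_IMAGES = 7575
NUM_TEST_IMAGES = 25250

# The final batch is smaller when the image count isn't a multiple of 32
train_batches = math.ceil(NUM_TRAIN_IMAGES / BATCH_SIZE)
test_batches = math.ceil(NUM_TEST_IMAGES / BATCH_SIZE)
last_train_batch = NUM_TRAIN_IMAGES - (train_batches - 1) * BATCH_SIZE
last_test_batch = NUM_TEST_IMAGES - (test_batches - 1) * BATCH_SIZE

print(f"Training batches: {train_batches} (last batch has {last_train_batch} images)")
print(f"Test batches: {test_batches} (last batch has {last_test_batch} images)")
```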

Train a big dog model with transfer learning on 10% of 101 food classes¶

Our food image data has been imported into TensorFlow, time to model it.

To keep our experiments swift, we're going to start by using feature extraction transfer learning with a pre-trained model for a few epochs and then fine-tune for a few more epochs.

More specifically, our goal will be to see if we can beat the baseline from the original Food101 paper (50.76% accuracy on 101 classes) with 10% of the training data and the following modelling setup:

  • A ModelCheckpoint callback to save our progress during training, this means we could experiment with further training later without having to train from scratch every time
  • Data augmentation built right into the model
  • A headless (no top layers) EfficientNetB0 architecture from tf.keras.applications as our base model
  • A Dense layer with 101 neurons (same as number of food classes) and softmax activation as the output layer
  • Categorical crossentropy as the loss function since we're dealing with more than two classes
  • The Adam optimizer with the default settings
  • Fitting for 5 full passes on the training data while evaluating on 15% of the test data

It seems like a lot but these are all things we've covered before in the Transfer Learning in TensorFlow Part 2: Fine-tuning notebook.
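
To get a feel for the softmax output layer and the categorical crossentropy loss in the setup above, here's a plain-Python sketch of both. It's a simplified illustration rather than the Keras implementation (the eps clipping value is an assumption), and it works out what a model guessing uniformly across 101 classes would score.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def categorical_crossentropy(one_hot_label, probs, eps=1e-7):
    """Negative log of the probability assigned to the true class."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(one_hot_label, probs))

NUM_CLASSES = 101

# A model with no preference: identical scores for every class
uniform_probs = softmax([0.0] * NUM_CLASSES)

# Hypothetical true label (class index 42, chosen arbitrarily)
label = [0.0] * NUM_CLASSES
label[42] = 1.0

chance_loss = categorical_crossentropy(label, uniform_probs)
chance_accuracy = 1 / NUM_CLASSES

print(f"Chance-level loss: {chance_loss:.4f}")
print(f"Chance-level accuracy: {chance_accuracy:.2%}")
```

These two numbers make a handy sanity check: if the first epoch reports a loss well above ln(101) ≈ 4.62 or an accuracy around 1%, the model isn't learning much yet.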

Let's start by creating the ModelCheckpoint callback.

Since we want our model to perform well on unseen data we'll set it to monitor the validation accuracy metric and save the model weights which score the best on that.

In [7]:
# Create checkpoint callback to save model for later use
checkpoint_path = "101_classes_10_percent_data_model_checkpoint"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                         save_weights_only=True, # save only the model weights
                                                         monitor="val_accuracy", # save the model weights which score the best validation accuracy
                                                         save_best_only=True) # only keep the best model weights on file (delete the rest)
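
To see what monitor="val_accuracy" combined with save_best_only=True means in practice, here's a toy sketch of the decision the callback makes at the end of each epoch. It's a simplified stand-in rather than Keras code, and the class name and accuracy values are hypothetical.

```python
class BestWeightsTracker:
    """Toy stand-in for the save_best_only behaviour (not Keras code)."""

    def __init__(self):
        self.best_score = float("-inf")
        self.saved_epoch = None

    def on_epoch_end(self, epoch, val_accuracy):
        """Return True if the weights would be saved this epoch."""
        if val_accuracy > self.best_score:
            self.best_score = val_accuracy
            self.saved_epoch = epoch  # a real callback would write the weights here
            return True
        return False

# Hypothetical validation accuracies over 5 epochs
hypothetical_val_accuracy = [0.41, 0.48, 0.47, 0.52, 0.51]

tracker = BestWeightsTracker()
saves = [tracker.on_epoch_end(epoch, acc)
         for epoch, acc in enumerate(hypothetical_val_accuracy, start=1)]

print(saves)
print(f"Weights on file come from epoch {tracker.saved_epoch}")
```

Epochs 3 and 5 don't improve on the best score seen so far, so the weights on file would still be the ones from epoch 4.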

Checkpoint ready. Now let's create a small data augmentation model with the Sequential API. Because we're working with a reduced-size training set, this will help prevent our model from overfitting on the training data.

In [8]:
# Import the required modules for model creation
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.models import Sequential

# Setup data augmentation
data_augmentation = Sequential([
  preprocessing.RandomFlip("horizontal"), # randomly flip images on horizontal edge
  preprocessing.RandomRotation(0.2), # randomly rotate images by a specific amount
  preprocessing.RandomHeight(0.2), # randomly adjust the height of an image by a specific amount
  preprocessing.RandomWidth(0.2), # randomly adjust the width of an image by a specific amount
  preprocessing.RandomZoom(0.2), # randomly zoom into an image
  # preprocessing.Rescaling(1./255) # keep for models like ResNet50V2, remove for EfficientNet
], name="data_augmentation")

Beautiful! We'll be able to insert the data_augmentation Sequential model as a layer in our Functional API model. That way if we want to continue training our model at a later time, the data augmentation is already built right in.
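
What does a factor of 0.2 actually mean for these layers? As a rough sketch: RandomRotation treats its factor as a fraction of a full turn (2π), while RandomHeight and RandomWidth change the size by up to ±20%. The arithmetic below assumes the 224-pixel image size we set earlier; the variable names are made up for illustration.

```python
import math

FACTOR = 0.2
IMG_HEIGHT = 224  # from IMG_SIZE = (224, 224)

# RandomRotation(0.2): rotate by up to +/- 20% of a full turn
max_rotation_radians = FACTOR * 2 * math.pi
max_rotation_degrees = FACTOR * 360

# RandomHeight(0.2): height changes by up to +/- 20% (the same applies to RandomWidth)
min_height = IMG_HEIGHT * (1 - FACTOR)
max_height = IMG_HEIGHT * (1 + FACTOR)

print(f"Rotation range: +/- {max_rotation_degrees:.0f} degrees ({max_rotation_radians:.3f} radians)")
print(f"Height range: {min_height:.1f} to {max_height:.1f} pixels")
```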

Speaking of Functional API models, time to put together a feature extraction transfer learning model using tf.keras.applications.EfficientNetB0 as our base model.

We'll import the base model using the parameter include_top=False so we can add on our own output layers, notably GlobalAveragePooling2D() (condense the outputs of the base model into a shape usable by the output layer) followed by a Dense layer.
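
To picture what GlobalAveragePooling2D() does, here's a plain-Python sketch of the same idea on a tiny 2x2 feature map with 2 channels: every channel is averaged over all spatial positions, leaving one number per channel. The function and values are made up for illustration.

```python
def global_average_pool(feature_map):
    """Average each channel over height and width.

    feature_map: nested list shaped (height, width, channels).
    Returns a list with one value per channel.
    """
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    pooled = []
    for c in range(channels):
        total = sum(feature_map[h][w][c] for h in range(height) for w in range(width))
        pooled.append(total / (height * width))
    return pooled

# Toy feature map: 2 x 2 spatial positions, 2 channels
toy_map = [
    [[1.0, 10.0], [2.0, 20.0]],
    [[3.0, 30.0], [4.0, 40.0]],
]

pooled = global_average_pool(toy_map)
print(pooled)  # [2.5, 25.0]
```

In our model the base model outputs 1280 channels, so the pooled result is a vector of 1280 values per image, which is what the Dense output layer receives.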

In [9]:
# Setup base model and freeze its layers (this will extract features)
base_model = tf.keras.applications.EfficientNetB0(include_top=False)
base_model.trainable = False

# Setup model architecture with trainable top layers
inputs = layers.Input(shape=(224, 224, 3), name="input_layer") # shape of input image
x = data_augmentation(inputs) # augment images (only happens during training)
x = base_model(x, training=False) # run the base model in inference mode so layers like batch normalization don't update their statistics (the weights stay frozen because base_model.trainable = False)
x = layers.GlobalAveragePooling2D(name="global_average_pooling")(x) # pool the outputs of the base model
outputs = layers.Dense(len(train_data_all_10_percent.class_names), activation="softmax", name="output_layer")(x) # same number of outputs as classes
model = tf.keras.Model(inputs, outputs)

A colourful figure of the model we've created with: 224x224 images as input, data augmentation as a layer, EfficientNetB0 as a backbone, an average pooling layer and a dense output layer with one neuron per class (the figure shows 10 neurons; the model in this notebook has 101).

Model created. Let's inspect it.

In [10]:
# Get a summary of our model
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_layer (InputLayer)    [(None, 224, 224, 3)]     0         
                                                                 
 data_augmentation (Sequenti  (None, 224, 224, 3)      0         
 al)                                                             
                                                                 
 efficientnetb0 (Functional)  (None, None, None, 1280)  4049571  
                                                                 
 global_average_pooling (Glo  (None, 1280)             0         
 balAveragePooling2D)                                            
                                                                 
 output_layer (Dense)        (None, 101)               129381    
                                                                 
=================================================================
Total params: 4,178,952
Trainable params: 129,381
Non-trainable params: 4,049,571
_________________________________________________________________

Looking good! Our Functional model has 5 layers, but each of those layers contains a varying number of layers within it.

Notice the number of trainable and non-trainable parameters. It seems the only trainable parameters are within the output_layer, which is exactly what we're after with this first run of feature extraction: keep all the learned patterns in the base model (EfficientNetB0) frozen whilst allowing the model to tune its outputs to our custom data.
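
As a quick sanity check, the parameter counts in the summary can be reproduced by hand. The sketch below only uses numbers printed in the summary above (1280 pooled features, 101 classes, 4,049,571 base model parameters); the variable names are just for illustration.

```python
pooled_features = 1280         # output size of global_average_pooling
num_classes = 101              # output size of output_layer
base_model_params = 4_049_571  # efficientnetb0 row in the summary (all frozen)

# A Dense layer has one weight per (input, output) pair plus one bias per output
dense_weights = pooled_features * num_classes
dense_biases = num_classes
dense_params = dense_weights + dense_biases

total_params = base_model_params + dense_params

print(f"{dense_params:,}")   # 129,381
print(f"{total_params:,}")   # 4,178,952
```

So every one of the 129,381 trainable parameters lives in the output layer, and everything else belongs to the frozen base model.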

Time to compile and fit.

In [11]:
# Compile
model.compile(loss="categorical_crossentropy",
              optimizer=tf.keras.optimizers.Adam(), # use Adam with default settings
              metrics=["accuracy"])

# Fit
history_all_classes_10_percent = model.fit(train_data_all_10_percent,
                                           epochs=5, # fit for 5 epochs to keep experiments quick
                                           validation_data=test_data,
                                           validation_steps=int(0.15 * len(test_data)), # evaluate on smaller portion of test data
                                           callbacks=[checkpoint_callback]) # save best model weights to file
Epoch 1/5
237/237 [==============================] - 92s 317ms/step - loss: 3.4741 - accuracy: 0.2450 - val_loss: 2.7287 - val_accuracy: 0.3874
Epoch 2/5
237/237 [==============================] - 58s 242ms/step - loss: 2.3539 - accuracy: 0.4590 - val_loss: 2.3206 - val_accuracy: 0.4412
Epoch 3/5
237/237 [==============================] - 54s 225ms/step - loss: 1.9770 - accuracy: 0.5286 - val_loss: 2.1639 - val_accuracy: 0.4550
Epoch 4/5
237/237 [==============================] - 49s 206ms/step - loss: 1.7602 - accuracy: 0.5721 - val_loss: 2.0784 - val_accuracy: 0.4695
Epoch 5/5
237/237 [==============================] - 49s 205ms/step - loss: 1.6035 - accuracy: 0.6044 - val_loss: 2.0258 - val_accuracy: 0.4801
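
To get a feel for what `validation_steps=int(0.15 * len(test_data))` means in practice, here's a small sketch of the arithmetic. The 790 is the number of test batches reported when evaluating on the full test set in this notebook; the batch size of 32 is an assumption (it's the default for `image_dataset_from_directory`), so adjust it if your data was batched differently.

```python
num_test_batches = 790   # len(test_data), as shown by the evaluation progress bar
assumed_batch_size = 32  # assumption: default batch size, change if yours differs

# Same calculation as the validation_steps argument passed to fit()
validation_steps = int(0.15 * num_test_batches)
images_per_validation_pass = validation_steps * assumed_batch_size

print(validation_steps)            # 118
print(images_per_validation_pass)  # 3776
```

In other words, each epoch's `val_accuracy` is measured on roughly 118 batches rather than all 790, which keeps training quick but makes the number a noisier estimate.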

Woah! It looks like our model is getting some impressive results, but remember, during training our model only evaluated on 15% of the test data. Let's see how it did on the whole test dataset.

In [12]:
# Evaluate model 
results_feature_extraction_model = model.evaluate(test_data)
results_feature_extraction_model
790/790 [==============================] - 61s 77ms/step - loss: 1.7081 - accuracy: 0.5573
Out[12]:
[1.7080934047698975, 0.557346522808075]

Well, it looks like we just beat our baseline (the results from the original Food101 paper) with 10% of the data! In around 5 minutes of training... that's the power of deep learning and, more precisely, transfer learning: leveraging what one model has learned on another dataset for our own dataset.

How do the loss curves look?

In [13]:
plot_loss_curves(history_all_classes_10_percent)

🤔 Question: What do these curves suggest? Hint: ideally, the two curves should be very similar to each other. If not, there may be some overfitting or underfitting.

Fine-tuning¶

Our feature extraction transfer learning model is performing well. Why don't we try to fine-tune a few layers in the base model and see if we gain any improvements?

The good news is, thanks to the ModelCheckpoint callback, we've got the saved weights of our already well-performing model so if fine-tuning doesn't add any benefits, we can revert back.

To fine-tune the base model we'll first set its trainable attribute to True, unfreezing all of the frozen layers.

Then since we've got a relatively small training dataset, we'll refreeze every layer except for the last 5, making them trainable.

In [14]:
# Unfreeze all of the layers in the base model
base_model.trainable = True

# Refreeze every layer except for the last 5
for layer in base_model.layers[:-5]:
  layer.trainable = False
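
If the `[:-5]` slice looks a little cryptic, the sketch below replays the same unfreeze-then-refreeze pattern on stand-in objects. `FakeLayer` is a hypothetical placeholder (not a Keras class), and the count of 237 layers is taken from the base model layer listing in this notebook.

```python
from dataclasses import dataclass

@dataclass
class FakeLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    name: str
    trainable: bool = False

# 237 stand-in layers, indexed 0 to 236 like base_model.layers
fake_layers = [FakeLayer(name=f"layer_{i}") for i in range(237)]

# Step 1: unfreeze everything
for layer in fake_layers:
    layer.trainable = True

# Step 2: refreeze every layer except the last 5
for layer in fake_layers[:-5]:
    layer.trainable = False

trainable_indices = [i for i, layer in enumerate(fake_layers) if layer.trainable]
print(trainable_indices)  # [232, 233, 234, 235, 236]
```

The slice `[:-5]` selects everything up to (but not including) the final five items, so only the last five layers stay trainable.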

We just made a change to the layers in our model and what do we have to do every time we make a change to our model?

Recompile it.

Because we're fine-tuning, we'll use a 10x lower learning rate to ensure the updates to the previously trained weights aren't too large.

When fine-tuning and unfreezing layers of your pre-trained model, it's common practice to lower the learning rate you used for your feature extraction model. How much by? A 10x lower learning rate is usually a good place to start.
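
Here's a rough sketch of why a lower learning rate means gentler updates. It uses a single plain gradient descent step rather than Adam's actual update rule (a deliberate simplification), and the weight and gradient values are hypothetical. The only number taken from Keras is Adam's default learning rate of 0.001.

```python
default_adam_lr = 1e-3               # Adam's default learning rate in Keras
fine_tune_lr = default_adam_lr / 10  # 10x lower for fine-tuning

# Hypothetical pre-trained weight and gradient, for illustration only
weight = 2.0
gradient = 0.5

# Simplified update: new_weight = weight - learning_rate * gradient
step_feature_extraction = default_adam_lr * gradient
step_fine_tune = fine_tune_lr * gradient

print(f"{fine_tune_lr:.4f}")                    # 0.0001
print(f"{weight - step_feature_extraction:.5f}")  # 1.99950
print(f"{weight - step_fine_tune:.5f}")           # 1.99995
```

With the same gradient, the fine-tuning step moves the weight a tenth as far, which helps protect the patterns the base model has already learned.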

In [15]:
# Recompile model with lower learning rate
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(1e-4), # 10x lower learning rate than default
              metrics=['accuracy'])

Model recompiled. How about we make sure the layers we want are trainable?

In [16]:
# What layers in the model are trainable?
for layer in model.layers:
  print(layer.name, layer.trainable)
input_layer True
data_augmentation True
efficientnetb0 True
global_average_pooling True
output_layer True

In [17]:
# Check which layers are trainable
for layer_number, layer in enumerate(base_model.layers):
  print(layer_number, layer.name, layer.trainable)
0 input_1 False
1 rescaling False
2 normalization False
3 stem_conv_pad False
4 stem_conv False
5 stem_bn False
6 stem_activation False
7 block1a_dwconv False
8 block1a_bn False
9 block1a_activation False
10 block1a_se_squeeze False
11 block1a_se_reshape False
12 block1a_se_reduce False
13 block1a_se_expand False
14 block1a_se_excite False
15 block1a_project_conv False
16 block1a_project_bn False
17 block2a_expand_conv False
18 block2a_expand_bn False
19 block2a_expand_activation False
20 block2a_dwconv_pad False
21 block2a_dwconv False
22 block2a_bn False
23 block2a_activation False
24 block2a_se_squeeze False
25 block2a_se_reshape False
26 block2a_se_reduce False
27 block2a_se_expand False
28 block2a_se_excite False
29 block2a_project_conv False
30 block2a_project_bn False
31 block2b_expand_conv False
32 block2b_expand_bn False
33 block2b_expand_activation False
34 block2b_dwconv False
35 block2b_bn False
36 block2b_activation False
37 block2b_se_squeeze False
38 block2b_se_reshape False
39 block2b_se_reduce False
40 block2b_se_expand False
41 block2b_se_excite False
42 block2b_project_conv False
43 block2b_project_bn False
44 block2b_drop False
45 block2b_add False
46 block3a_expand_conv False
47 block3a_expand_bn False
48 block3a_expand_activation False
49 block3a_dwconv_pad False
50 block3a_dwconv False
51 block3a_bn False
52 block3a_activation False
53 block3a_se_squeeze False
54 block3a_se_reshape False
55 block3a_se_reduce False
56 block3a_se_expand False
57 block3a_se_excite False
58 block3a_project_conv False
59 block3a_project_bn False
60 block3b_expand_conv False
61 block3b_expand_bn False
62 block3b_expand_activation False
63 block3b_dwconv False
64 block3b_bn False
65 block3b_activation False
66 block3b_se_squeeze False
67 block3b_se_reshape False
68 block3b_se_reduce False
69 block3b_se_expand False
70 block3b_se_excite False
71 block3b_project_conv False
72 block3b_project_bn False
73 block3b_drop False
74 block3b_add False
75 block4a_expand_conv False
76 block4a_expand_bn False
77 block4a_expand_activation False
78 block4a_dwconv_pad False
79 block4a_dwconv False
80 block4a_bn False
81 block4a_activation False
82 block4a_se_squeeze False
83 block4a_se_reshape False
84 block4a_se_reduce False
85 block4a_se_expand False
86 block4a_se_excite False
87 block4a_project_conv False
88 block4a_project_bn False
89 block4b_expand_conv False
90 block4b_expand_bn False
91 block4b_expand_activation False
92 block4b_dwconv False
93 block4b_bn False
94 block4b_activation False
95 block4b_se_squeeze False
96 block4b_se_reshape False
97 block4b_se_reduce False
98 block4b_se_expand False
99 block4b_se_excite False
100 block4b_project_conv False
101 block4b_project_bn False
102 block4b_drop False
103 block4b_add False
104 block4c_expand_conv False
105 block4c_expand_bn False
106 block4c_expand_activation False
107 block4c_dwconv False
108 block4c_bn False
109 block4c_activation False
110 block4c_se_squeeze False
111 block4c_se_reshape False
112 block4c_se_reduce False
113 block4c_se_expand False
114 block4c_se_excite False
115 block4c_project_conv False
116 block4c_project_bn False
117 block4c_drop False
118 block4c_add False
119 block5a_expand_conv False
120 block5a_expand_bn False
121 block5a_expand_activation False
122 block5a_dwconv False
123 block5a_bn False
124 block5a_activation False
125 block5a_se_squeeze False
126 block5a_se_reshape False
127 block5a_se_reduce False
128 block5a_se_expand False
129 block5a_se_excite False
130 block5a_project_conv False
131 block5a_project_bn False
132 block5b_expand_conv False
133 block5b_expand_bn False
134 block5b_expand_activation False
135 block5b_dwconv False
136 block5b_bn False
137 block5b_activation False
138 block5b_se_squeeze False
139 block5b_se_reshape False
140 block5b_se_reduce False
141 block5b_se_expand False
142 block5b_se_excite False
143 block5b_project_conv False
144 block5b_project_bn False
145 block5b_drop False
146 block5b_add False
147 block5c_expand_conv False
148 block5c_expand_bn False
149 block5c_expand_activation False
150 block5c_dwconv False
151 block5c_bn False
152 block5c_activation False
153 block5c_se_squeeze False
154 block5c_se_reshape False
155 block5c_se_reduce False
156 block5c_se_expand False
157 block5c_se_excite False
158 block5c_project_conv False
159 block5c_project_bn False
160 block5c_drop False
161 block5c_add False
162 block6a_expand_conv False
163 block6a_expand_bn False
164 block6a_expand_activation False
165 block6a_dwconv_pad False
166 block6a_dwconv False
167 block6a_bn False
168 block6a_activation False
169 block6a_se_squeeze False
170 block6a_se_reshape False
171 block6a_se_reduce False
172 block6a_se_expand False
173 block6a_se_excite False
174 block6a_project_conv False
175 block6a_project_bn False
176 block6b_expand_conv False
177 block6b_expand_bn False
178 block6b_expand_activation False
179 block6b_dwconv False
180 block6b_bn False
181 block6b_activation False
182 block6b_se_squeeze False
183 block6b_se_reshape False
184 block6b_se_reduce False
185 block6b_se_expand False
186 block6b_se_excite False
187 block6b_project_conv False
188 block6b_project_bn False
189 block6b_drop False
190 block6b_add False
191 block6c_expand_conv False
192 block6c_expand_bn False
193 block6c_expand_activation False
194 block6c_dwconv False
195 block6c_bn False
196 block6c_activation False
197 block6c_se_squeeze False
198 block6c_se_reshape False
199 block6c_se_reduce False
200 block6c_se_expand False
201 block6c_se_excite False
202 block6c_project_conv False
203 block6c_project_bn False
204 block6c_drop False
205 block6c_add False
206 block6d_expand_conv False
207 block6d_expand_bn False
208 block6d_expand_activation False
209 block6d_dwconv False
210 block6d_bn False
211 block6d_activation False
212 block6d_se_squeeze False
213 block6d_se_reshape False
214 block6d_se_reduce False
215 block6d_se_expand False
216 block6d_se_excite False
217 block6d_project_conv False
218 block6d_project_bn False
219 block6d_drop False
220 block6d_add False
221 block7a_expand_conv False
222 block7a_expand_bn False
223 block7a_expand_activation False
224 block7a_dwconv False
225 block7a_bn False
226 block7a_activation False
227 block7a_se_squeeze False
228 block7a_se_reshape False
229 block7a_se_reduce False
230 block7a_se_expand False
231 block7a_se_excite False
232 block7a_project_conv True
233 block7a_project_bn True
234 top_conv True
235 top_bn True
236 top_activation True

Excellent! Time to fine-tune our model.

Another 5 epochs should be enough to see whether any benefits come about (though we could always try more).

We'll start the training off where the feature extraction model left off using the initial_epoch parameter in the fit() function.

In [18]:
# Fine-tune for 5 more epochs
fine_tune_epochs = 10 # model has already done 5 epochs, this is the total number of epochs we're after (5+5=10)

history_all_classes_10_percent_fine_tune = model.fit(train_data_all_10_percent,
                                                     epochs=fine_tune_epochs,
                                                     validation_data=test_data,
                                                     validation_steps=int(0.15 * len(test_data)), # validate on 15% of the test data
                                                     initial_epoch=history_all_classes_10_percent.epoch[-1]) # start from previous last epoch (history.epoch is zero-indexed, so this is 4 and the log resumes at "Epoch 5/10")
Epoch 5/10
237/237 [==============================] - 54s 200ms/step - loss: 1.3445 - accuracy: 0.6552 - val_loss: 1.9766 - val_accuracy: 0.4971
Epoch 6/10
237/237 [==============================] - 44s 183ms/step - loss: 1.2294 - accuracy: 0.6747 - val_loss: 1.9722 - val_accuracy: 0.4928
Epoch 7/10
237/237 [==============================] - 44s 184ms/step - loss: 1.1433 - accuracy: 0.6964 - val_loss: 1.9629 - val_accuracy: 0.4997
Epoch 8/10
237/237 [==============================] - 41s 171ms/step - loss: 1.0688 - accuracy: 0.7156 - val_loss: 2.0072 - val_accuracy: 0.4918
Epoch 9/10
237/237 [==============================] - 41s 169ms/step - loss: 1.0142 - accuracy: 0.7278 - val_loss: 1.9400 - val_accuracy: 0.5069
Epoch 10/10
237/237 [==============================] - 41s 172ms/step - loss: 0.9623 - accuracy: 0.7423 - val_loss: 1.9723 - val_accuracy: 0.4955
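
You might have noticed the log above starts at "Epoch 5/10" rather than "Epoch 6/10". The sketch below walks through the arithmetic, assuming `fit()` runs the zero-indexed epochs from `initial_epoch` up to (but not including) `epochs` and displays them one-indexed. The list standing in for `history.epoch` is an assumption based on the 5 epochs we trained earlier.

```python
previous_epochs = list(range(5))     # stand-in for history_all_classes_10_percent.epoch
initial_epoch = previous_epochs[-1]  # last zero-indexed epoch -> 4
total_epochs = 10                    # fine_tune_epochs

# Zero-indexed epochs that will be run, and how the progress log labels them
epochs_run = list(range(initial_epoch, total_epochs))
log_labels = [f"Epoch {epoch + 1}/{total_epochs}" for epoch in epochs_run]

print(previous_epochs)  # [0, 1, 2, 3, 4]
print(len(epochs_run))  # 6
print(log_labels[0])    # Epoch 5/10
print(log_labels[-1])   # Epoch 10/10
```

So with `epoch[-1]` the fine-tuning run covers 6 epochs. If you wanted exactly 5 more, one option would be to pass `initial_epoch=len(history_all_classes_10_percent.epoch)` instead, which would start the log at "Epoch 6/10".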

Once again, during training we were only evaluating on a small portion of the test data. Let's find out how our model went on all of the test data.

In [19]:
# Evaluate fine-tuned model on the whole test dataset
results_all_classes_10_percent_fine_tune = model.evaluate(test_data)
results_all_classes_10_percent_fine_tune
790/790 [==============================] - 59s 75ms/step - loss: 1.6230 - accuracy: 0.5723
Out[19]:
[1.6229630708694458, 0.5723168253898621]

Hmm... it seems like our model got a slight boost from fine-tuning.
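
To put a number on "slight boost", the sketch below compares the two full test set accuracies from this notebook. The baseline of 50.76% is the top-1 accuracy usually quoted from the original Food101 paper; treat it as an assumption and check it against the paper if you rely on it.

```python
paper_baseline_acc = 0.5076                # assumption: figure quoted from the Food101 paper
feature_extraction_acc = 0.557346522808075 # from model.evaluate() before fine-tuning
fine_tuned_acc = 0.5723168253898621        # from model.evaluate() after fine-tuning

# Differences expressed in percentage points
fine_tune_gain = (fine_tuned_acc - feature_extraction_acc) * 100
gain_over_baseline = (fine_tuned_acc - paper_baseline_acc) * 100

print(f"{fine_tune_gain:.2f}")      # 1.50
print(f"{gain_over_baseline:.2f}")  # 6.47
```

Fine-tuning added about 1.5 percentage points on the full test set, a much smaller jump than the one we got from feature extraction in the first place.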

We might get a better picture by using our compare_historys() function and seeing what the training curves say.

In [20]:
compare_historys(original_history=history_all_classes_10_percent,
                 new_history=history_all_classes_10_percent_fine_tune,
                 initial_epochs=5)

It seems that after fine-tuning, our model's training metrics improved significantly but validation, not so much. Looks like our model is starting to overfit.

This is okay though, it's very often the case that fine-tuning leads to overfitting when the data a pre-trained model has been trained on is similar to your custom data.

In our case, our pre-trained model, EfficientNetB0, was trained on ImageNet, which contains many real-life pictures of food, just like our food dataset.

If feature extraction already works well, the improvements you see from fine-tuning may not be as great as if your dataset was significantly different from the data your base model was pre-trained on.
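
The overfitting described above can also be read straight off the training logs. The sketch below takes the final-epoch numbers from the two `fit()` runs in this notebook and measures the gap between training and validation accuracy. Keep in mind the validation figures were computed on only 15% of the test data, so this is a rough indicator rather than a precise measurement.

```python
# Final-epoch metrics copied from the training logs in this notebook
feature_extraction_last = {"accuracy": 0.6044, "val_accuracy": 0.4801}
fine_tune_last = {"accuracy": 0.7423, "val_accuracy": 0.4955}

def generalization_gap(metrics):
    """Training accuracy minus validation accuracy (bigger gap suggests more overfitting)."""
    return metrics["accuracy"] - metrics["val_accuracy"]

gap_before = generalization_gap(feature_extraction_last)
gap_after = generalization_gap(fine_tune_last)

print(f"{gap_before:.4f}")  # 0.1243
print(f"{gap_after:.4f}")   # 0.2468
```

The gap roughly doubled during fine-tuning: training accuracy climbed by about 14 points while validation accuracy moved by less than 2.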

Saving our trained model¶

To prevent having to retrain our model from scratch, let's save it to file using the save() method.

In [21]:
# Save model to drive so it can be used later
#from google.colab import drive
#drive.mount('/content/drive')

#model.save("drive/MyDrive/tensorflow_course/101_food_class_10_percent_saved_big_dog_model")
INFO:tensorflow:Assets written to: drive/MyDrive/tensorflow_course/101_food_class_10_percent_saved_big_dog_model/assets

Evaluating the performance of the big dog model across all different classes¶

We've got a trained and saved model which according to the evaluation metrics we've used is performing fairly well.

But metrics schmetrics, let's dive a little deeper into our model's performance and get some visualizations going.

To do so, we'll load in the saved model and use it to make some predictions on the test dataset.

🔑 Note: Evaluating a machine learning model is as important as training one. Metrics can be deceiving. You should always visualize your model's performance on unseen data to make sure you aren't being fooled by good-looking training numbers.

In [22]:
import tensorflow as tf

# Download pre-trained model from Google Storage (like a cooking show, I trained this model earlier, so the results may be different than above)
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
saved_model_path = "06_101_food_class_10_percent_saved_big_dog_model.zip"
unzip_data(saved_model_path)

# Note: loading a model will output a lot of 'WARNINGS', these can be ignored: https://www.tensorflow.org/tutorials/keras/save_and_load#save_checkpoints_during_training
# There's also a thread on GitHub trying to fix these warnings: https://github.com/tensorflow/tensorflow/issues/40166
# model = tf.keras.models.load_model("drive/My Drive/tensorflow_course/101_food_class_10_percent_saved_big_dog_model/") # path to drive model
model = tf.keras.models.load_model(saved_model_path.split(".")[0]) # don't include ".zip" in loaded model path
--2022-07-30 16:57:01--  https://storage.googleapis.com/ztm_tf_course/food_vision/06_101_food_class_10_percent_saved_big_dog_model.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.128.128, 142.251.6.128, 142.250.152.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.128.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 46760742 (45M) [application/zip]
Saving to: ‘06_101_food_class_10_percent_saved_big_dog_model.zip.1’

06_101_food_class_1 100%[===================>]  44.59M  42.3MB/s    in 1.1s    

2022-07-30 16:57:03 (42.3 MB/s) - ‘06_101_food_class_10_percent_saved_big_dog_model.zip.1’ saved [46760742/46760742]

WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), *NOT* tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_419470) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_446460) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_activation_layer_call_and_return_conditional_losses_450449) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_415747) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_activation_layer_call_and_return_conditional_losses_416083) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_activation_layer_call_and_return_conditional_losses_450775) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_activation_layer_call_and_return_conditional_losses_451847) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_expand_activation_layer_call_and_return_conditional_losses_417915) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_se_reduce_layer_call_and_return_conditional_losses_451887) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_expand_activation_layer_call_and_return_conditional_losses_452467) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_functional_17_layer_call_and_return_conditional_losses_438312) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_expand_activation_layer_call_and_return_conditional_losses_417583) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_activation_layer_call_and_return_conditional_losses_418582) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_se_reduce_layer_call_and_return_conditional_losses_454031) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_activation_layer_call_and_return_conditional_losses_455436) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_activation_layer_call_and_return_conditional_losses_415524) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_activation_layer_call_and_return_conditional_losses_451474) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_expand_activation_layer_call_and_return_conditional_losses_451768) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_441729) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_se_reduce_layer_call_and_return_conditional_losses_454357) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_activation_layer_call_and_return_conditional_losses_416695) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_expand_activation_layer_call_and_return_conditional_losses_454238) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_functional_17_layer_call_and_return_conditional_losses_436681) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_activation_layer_call_and_return_conditional_losses_415804) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_activation_layer_call_and_return_conditional_losses_452919) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_se_reduce_layer_call_and_return_conditional_losses_453658) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_448082) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_activation_layer_call_and_return_conditional_losses_418915) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_expand_activation_layer_call_and_return_conditional_losses_453539) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_se_reduce_layer_call_and_return_conditional_losses_452586) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_se_reduce_layer_call_and_return_conditional_losses_450163) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_se_reduce_layer_call_and_return_conditional_losses_418018) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_expand_activation_layer_call_and_return_conditional_losses_455357) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_activation_layer_call_and_return_conditional_losses_417639) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_se_reduce_layer_call_and_return_conditional_losses_451188) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_activation_layer_call_and_return_conditional_losses_420190) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_stem_activation_layer_call_and_return_conditional_losses_415468) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_se_reduce_layer_call_and_return_conditional_losses_455476) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_se_reduce_layer_call_and_return_conditional_losses_417354) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_se_reduce_layer_call_and_return_conditional_losses_452213) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_activation_layer_call_and_return_conditional_losses_452173) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_se_reduce_layer_call_and_return_conditional_losses_415571) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_se_reduce_layer_call_and_return_conditional_losses_451514) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_activation_layer_call_and_return_conditional_losses_417971) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_se_reduce_layer_call_and_return_conditional_losses_454730) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_se_reduce_layer_call_and_return_conditional_losses_416742) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_se_reduce_layer_call_and_return_conditional_losses_450489) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_activation_layer_call_and_return_conditional_losses_451148) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_expand_activation_layer_call_and_return_conditional_losses_418194) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_se_reduce_layer_call_and_return_conditional_losses_416463) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_429711) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_443351) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_expand_activation_layer_call_and_return_conditional_losses_418526) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_activation_layer_call_and_return_conditional_losses_453245) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_activation_layer_call_and_return_conditional_losses_416416) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_efficientnetb0_layer_call_and_return_conditional_losses_428089) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_expand_activation_layer_call_and_return_conditional_losses_416027) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_expand_activation_layer_call_and_return_conditional_losses_453912) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_activation_layer_call_and_return_conditional_losses_452546) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_se_reduce_layer_call_and_return_conditional_losses_420237) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_se_reduce_layer_call_and_return_conditional_losses_418629) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_expand_activation_layer_call_and_return_conditional_losses_416359) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_expand_activation_layer_call_and_return_conditional_losses_451395) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_activation_layer_call_and_return_conditional_losses_454690) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_se_reduce_layer_call_and_return_conditional_losses_419905) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_activation_layer_call_and_return_conditional_losses_419526) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_se_reduce_layer_call_and_return_conditional_losses_418297) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_expand_activation_layer_call_and_return_conditional_losses_452094) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference__wrapped_model_408990) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5c_activation_layer_call_and_return_conditional_losses_453618) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_expand_activation_layer_call_and_return_conditional_losses_454984) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_expand_activation_layer_call_and_return_conditional_losses_450696) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_expand_activation_layer_call_and_return_conditional_losses_418858) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_stem_activation_layer_call_and_return_conditional_losses_450044) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_activation_layer_call_and_return_conditional_losses_418250) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_activation_layer_call_and_return_conditional_losses_453991) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_se_reduce_layer_call_and_return_conditional_losses_453285) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_expand_activation_layer_call_and_return_conditional_losses_416971) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_455683) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_se_reduce_layer_call_and_return_conditional_losses_415851) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5b_expand_activation_layer_call_and_return_conditional_losses_453166) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_top_activation_layer_call_and_return_conditional_losses_420413) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block1a_activation_layer_call_and_return_conditional_losses_450123) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_se_reduce_layer_call_and_return_conditional_losses_417075) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_expand_activation_layer_call_and_return_conditional_losses_452840) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_activation_layer_call_and_return_conditional_losses_417307) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_activation_layer_call_and_return_conditional_losses_455063) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_expand_activation_layer_call_and_return_conditional_losses_419802) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_activation_layer_call_and_return_conditional_losses_419858) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block5a_se_reduce_layer_call_and_return_conditional_losses_452959) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3a_expand_activation_layer_call_and_return_conditional_losses_451069) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2a_expand_activation_layer_call_and_return_conditional_losses_450370) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_expand_activation_layer_call_and_return_conditional_losses_419138) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_activation_layer_call_and_return_conditional_losses_419194) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_se_reduce_layer_call_and_return_conditional_losses_419573) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block7a_expand_activation_layer_call_and_return_conditional_losses_420134) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4a_activation_layer_call_and_return_conditional_losses_417028) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6c_expand_activation_layer_call_and_return_conditional_losses_454611) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block3b_expand_activation_layer_call_and_return_conditional_losses_416639) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4c_se_reduce_layer_call_and_return_conditional_losses_417686) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block4b_expand_activation_layer_call_and_return_conditional_losses_417251) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6d_se_reduce_layer_call_and_return_conditional_losses_455103) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_se_reduce_layer_call_and_return_conditional_losses_450815) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block2b_se_reduce_layer_call_and_return_conditional_losses_416130) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_activation_layer_call_and_return_conditional_losses_454317) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6a_se_reduce_layer_call_and_return_conditional_losses_418962) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.
WARNING:absl:Importing a function (__inference_block6b_se_reduce_layer_call_and_return_conditional_losses_419241) with ops with unsaved custom gradients. Will likely fail if a gradient is requested.

To make sure our loaded model is indeed a trained model, let's evaluate its performance on the test dataset.

In [23]:
# Check to see if loaded model is a trained model
loaded_loss, loaded_accuracy = model.evaluate(test_data)
loaded_loss, loaded_accuracy
790/790 [==============================] - 62s 76ms/step - loss: 1.8027 - accuracy: 0.6078
Out[23]:
(1.8027206659317017, 0.6077623963356018)

Wonderful! It looks like our loaded model is performing just as well as it was before we saved it. Let's make some predictions.

Making predictions with our trained model¶

To evaluate our trained model, we need to make some predictions with it and then compare those predictions to the test dataset.

Because the model has never seen the test dataset, this should give us an indication of how the model will perform in the real world on data similar to what it has been trained on.

To make predictions with our trained model, we can use the predict() method passing it the test data.

Since our data is multi-class, doing this will return a prediction probability tensor for each sample.

In other words, every time the trained model sees an image it will compare it to all of the patterns it learned during training and return an output for every class (all 101 of them) indicating how likely the image is to be that class.

In [24]:
# Make predictions with model
pred_probs = model.predict(test_data, verbose=1) # set verbosity to see how long it will take 
790/790 [==============================] - 64s 79ms/step

We just passed all of the test images to our model and asked it to make a prediction on what food it thinks is in each.

So if we had 25250 images in the test dataset, how many predictions do you think we should have?

In [25]:
# How many predictions are there?
len(pred_probs)
Out[25]:
25250

And if each image could be one of 101 classes, how many predictions do you think we'll have for each image?

In [26]:
# What's the shape of our predictions?
pred_probs.shape
Out[26]:
(25250, 101)

What we've got is often referred to as a prediction probability tensor (or array).

Let's see what the first 10 look like.

In [27]:
# How do they look?
pred_probs[:10]
Out[27]:
array([[5.9541997e-02, 3.5742044e-06, 4.1377008e-02, ..., 1.4138679e-09,
        8.3530824e-05, 3.0897509e-03],
       [9.6401680e-01, 1.3753272e-09, 8.4780820e-04, ..., 5.4287048e-05,
        7.8362204e-12, 9.8466024e-10],
       [9.5925868e-01, 3.2533648e-05, 1.4866963e-03, ..., 7.1891270e-07,
        5.4397265e-07, 4.0275925e-05],
       ...,
       [4.7313246e-01, 1.2931228e-07, 1.4805609e-03, ..., 5.9750048e-04,
        6.6969005e-05, 2.3469245e-05],
       [4.4571832e-02, 4.7265476e-07, 1.2258517e-01, ..., 6.3498501e-06,
        7.5318512e-06, 3.6778776e-03],
       [7.2438955e-01, 1.9249771e-09, 5.2310937e-05, ..., 1.2291373e-03,
        1.5792714e-09, 9.6395648e-05]], dtype=float32)

Alright, it seems like we've got a bunch of tensors of really small numbers, how about we zoom into one of them?

In [28]:
# We get one prediction probability per class
print(f"Number of prediction probabilities for sample 0: {len(pred_probs[0])}")
print(f"What prediction probability sample 0 looks like:\n {pred_probs[0]}")
print(f"The class with the highest predicted probability by the model for sample 0: {pred_probs[0].argmax()}")
Number of prediction probabilities for sample 0: 101
What prediction probability sample 0 looks like:
 [5.9541997e-02 3.5742044e-06 4.1377008e-02 1.0660534e-09 8.1614111e-09
 8.6639478e-09 8.0927191e-07 8.5652403e-07 1.9859070e-05 8.0977674e-07
 3.1727800e-09 9.8673911e-07 2.8532185e-04 7.8049184e-10 7.4230076e-04
 3.8916409e-05 6.4740357e-06 2.4977301e-06 3.7891128e-05 2.0678326e-07
 1.5538435e-05 8.1507017e-07 2.6230514e-06 2.0010684e-07 8.3827518e-07
 5.4216030e-06 3.7390816e-06 1.3150530e-08 2.7761480e-03 2.8051860e-05
 6.8562017e-10 2.5574853e-05 1.6688880e-04 7.6406842e-10 4.0452872e-04
 1.3150632e-08 1.7957391e-06 1.4448174e-06 2.3062943e-02 8.2466846e-07
 8.5365838e-07 1.7138658e-06 7.0525025e-06 1.8402130e-08 2.8553373e-07
 7.9483443e-06 2.0681568e-06 1.8525114e-07 3.3619767e-08 3.1522580e-04
 1.0410922e-05 8.5448272e-07 8.4741855e-01 1.0555444e-05 4.4094620e-07
 3.7404177e-05 3.5306122e-05 3.2489035e-05 6.7314730e-05 1.2852565e-08
 2.6219704e-10 1.0318108e-05 8.5744112e-05 1.0569904e-06 2.1293351e-06
 3.7637546e-05 7.5972999e-08 2.5340536e-04 9.2905668e-07 1.2598113e-04
 6.2621534e-06 1.2458702e-08 4.0519491e-05 6.8728106e-08 1.2546303e-06
 5.2887280e-08 7.5424914e-08 7.5398340e-05 7.7540433e-05 6.4025880e-07
 9.9033480e-07 2.2225879e-05 1.5013875e-05 1.4038460e-07 1.2232531e-05
 1.9044764e-02 4.9999646e-05 4.6226128e-06 1.5388179e-07 3.3824128e-07
 3.9228252e-09 1.6563671e-07 8.1320752e-05 4.8965157e-06 2.4068302e-07
 2.3124045e-05 3.1040644e-04 3.1379939e-05 1.4138679e-09 8.3530824e-05
 3.0897509e-03]
The class with the highest predicted probability by the model for sample 0: 52

As we discussed before, for each image tensor we pass to our model, because of the number of output neurons and the activation function in the last layer (layers.Dense(len(train_data_all_10_percent.class_names), activation="softmax")), it outputs a prediction probability between 0 and 1 for each of the 101 classes.

And the index of the highest prediction probability can be considered what the model thinks is the most likely label. Similarly, the lower the prediction probability value, the less the model thinks that the target image is that specific class.

🔑 Note: Due to the nature of the softmax activation function, the sum of each of the prediction probabilities for a single sample will be 1 (or at least very close to 1). E.g. pred_probs[0].sum() = 1.
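To make the note above concrete, here's a minimal NumPy sketch of softmax on its own. The softmax() helper and the example_logits values are made up for illustration (they're not taken from our model), but they show the same behaviour: every output sits between 0 and 1 and the outputs sum to (very close to) 1.

```python
import numpy as np

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    exps = np.exp(shifted)
    return exps / exps.sum()

# Hypothetical raw outputs (logits) for a 4-class problem
example_logits = np.array([2.0, 1.0, 0.1, -1.5], dtype=np.float32)
example_probs = softmax(example_logits)

print(example_probs)           # every value is between 0 and 1
print(example_probs.sum())     # very close to 1
print(example_probs.argmax())  # index of the most likely class
```

The largest raw score always ends up with the largest probability, which is why argmax() on the probabilities gives the predicted class.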

We can find the index of the maximum value in each prediction probability tensor using the argmax() method.

In [29]:
# Get the class predictions of each sample
pred_classes = pred_probs.argmax(axis=1)

# How do they look?
pred_classes[:10]
Out[29]:
array([52,  0,  0, 80, 79, 61, 29,  0, 85,  0])

Beautiful! We've now got the predicted class index for each of the samples in our test dataset.

We'll be able to compare these to the test dataset labels to further evaluate our model.

To get the test dataset labels we can unravel our test_data object (which is in the form of a tf.data.Dataset) using the unbatch() method.

Doing this will give us access to the images and labels in the test dataset. Since the labels are in one-hot encoded format, we'll use the argmax() method to return the index of the label.

🔑 Note: This unravelling is why we set shuffle=False when creating the test data object. Otherwise, whenever we loaded the test dataset (like when making predictions), it would be shuffled every time, meaning if we tried to compare our predictions to the labels, they would be in different orders.
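Here's a small sketch of both ideas using plain NumPy instead of a tf.data.Dataset. The one_hot_labels array is a made-up toy example (4 samples, 3 classes). It shows how argmax() turns one-hot rows into integer labels, and how the same predictions look wrong if the labels come back in a different order.

```python
import numpy as np

# Hypothetical one-hot labels for 4 samples across 3 classes
one_hot_labels = np.array([[1, 0, 0],
                           [0, 0, 1],
                           [0, 1, 0],
                           [0, 0, 1]])

# argmax along the class axis recovers the integer label of each sample
int_labels = one_hot_labels.argmax(axis=1)
print(int_labels)  # [0 2 1 2]

# Predictions that are perfectly correct when the order is preserved
preds = int_labels.copy()
print((preds == int_labels).mean())  # 1.0

# If the labels came back in a different order (here: reversed),
# the very same predictions no longer line up with them
reordered_labels = int_labels[::-1]
print((preds == reordered_labels).mean())  # 0.0
```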

In [30]:
# Note: This might take a minute or so due to unravelling 790 batches
y_labels = []
for images, labels in test_data.unbatch(): # unbatch the test data and get images and labels
  y_labels.append(labels.numpy().argmax()) # append the index which has the largest value, the "1" in an array of 0s and 1s (labels are one-hot)
y_labels[:10] # check what they look like (unshuffled)
Out[30]:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Nice! Since test_data isn't shuffled, the y_labels array comes back in the same order as the pred_classes array.

The final check is to see how many labels we've got.

In [31]:
# How many labels are there? (should be the same as how many prediction probabilities we have)
len(y_labels)
Out[31]:
25250

As expected, the number of labels matches the number of images we've got. Time to compare our model's predictions with the ground truth labels.

Evaluating our model's predictions¶

A very simple evaluation is to use Scikit-Learn's accuracy_score() function which compares truth labels to predicted labels and returns an accuracy score.

If we've created our y_labels and pred_classes arrays correctly, this should return the same accuracy value (or at least very close) as the evaluate() method we used earlier.
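For intuition, accuracy is just the fraction of positions where the predicted label matches the truth label. A minimal sketch with made-up arrays (example_labels and example_preds are hypothetical, not our real y_labels and pred_classes):

```python
import numpy as np

# Hypothetical ground truth labels and predicted labels for 8 samples
example_labels = np.array([0, 0, 1, 1, 2, 2, 2, 1])
example_preds = np.array([0, 1, 1, 1, 2, 0, 2, 1])

# Accuracy = number of matching positions / total number of samples
manual_accuracy = (example_labels == example_preds).mean()
print(manual_accuracy)  # 0.75 (6 out of 8 match)
```

This element-by-element comparison is also why the two arrays need to be in the same order.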

In [32]:
# Get accuracy score by comparing predicted classes to ground truth labels
from sklearn.metrics import accuracy_score
sklearn_accuracy = accuracy_score(y_labels, pred_classes)
sklearn_accuracy
Out[32]:
0.6077623762376237
In [33]:
# Does the evaluate method compare to the Scikit-Learn measured accuracy?
import numpy as np
print(f"Close? {np.isclose(loaded_accuracy, sklearn_accuracy)} | Difference: {loaded_accuracy - sklearn_accuracy}")
Close? True | Difference: 2.0097978059574473e-08

Okay, it looks like our pred_classes and y_labels arrays are in the right order.

How about we get a little bit more visual with a confusion matrix?

To do so, we'll use our make_confusion_matrix function we created in a previous notebook.

In [34]:
# We'll import our make_confusion_matrix function from https://github.com/mrdbourke/tensorflow-deep-learning/blob/main/extras/helper_functions.py 
# But if you run it out of the box, it doesn't really work for 101 classes...
# the cell below adds a little functionality to make it readable.
from helper_functions import make_confusion_matrix
In [35]:
# Note: The following confusion matrix code is a remix of Scikit-Learn's 
# plot_confusion_matrix function - https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html
import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix

# Our function needs a different name to sklearn's plot_confusion_matrix
def make_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False): 
  """Makes a labelled confusion matrix comparing predictions and ground truth labels.

  If classes is passed, confusion matrix will be labelled, if not, integer class values
  will be used.

  Args:
    y_true: Array of truth labels (must be same shape as y_pred).
    y_pred: Array of predicted labels (must be same shape as y_true).
    classes: Array of class labels (e.g. string form). If `None`, integer labels are used.
    figsize: Size of output figure (default=(10, 10)).
    text_size: Size of output figure text (default=15).
    norm: normalize values or not (default=False).
    savefig: save confusion matrix to file (default=False).
  
  Returns:
    A labelled confusion matrix plot comparing y_true and y_pred.

  Example usage:
    make_confusion_matrix(y_true=test_labels, # ground truth test labels
                          y_pred=y_preds, # predicted labels
                          classes=class_names, # array of class label names
                          figsize=(15, 15),
                          text_size=10)
  """  
  # Create the confusion matrix
  cm = confusion_matrix(y_true, y_pred)
  cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] # normalize it
  n_classes = cm.shape[0] # find the number of classes we're dealing with

  # Plot the figure and make it pretty
  fig, ax = plt.subplots(figsize=figsize)
  cax = ax.matshow(cm, cmap=plt.cm.Blues) # colors will represent how 'correct' a class is, darker == better
  fig.colorbar(cax)

  # Is there a list of classes? (compare to None so NumPy arrays of class names also work)
  if classes is not None:
    labels = classes
  else:
    labels = np.arange(cm.shape[0])
  
  # Label the axes
  ax.set(title="Confusion Matrix",
         xlabel="Predicted label",
         ylabel="True label",
         xticks=np.arange(n_classes), # create enough axis slots for each class
         yticks=np.arange(n_classes), 
         xticklabels=labels, # axes will be labeled with class names (if they exist) or ints
         yticklabels=labels)
  
  # Make x-axis labels appear on bottom
  ax.xaxis.set_label_position("bottom")
  ax.xaxis.tick_bottom()

  ### Added: Rotate xticks for readability & increase font size (required due to such a large confusion matrix)
  plt.xticks(rotation=70, fontsize=text_size)
  plt.yticks(fontsize=text_size)

  # Set the threshold for different colors
  threshold = (cm.max() + cm.min()) / 2.

  # Plot the text on each cell
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    if norm:
      plt.text(j, i, f"{cm[i, j]} ({cm_norm[i, j]*100:.1f}%)",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)
    else:
      plt.text(j, i, f"{cm[i, j]}",
              horizontalalignment="center",
              color="white" if cm[i, j] > threshold else "black",
              size=text_size)

  # Save the figure to the current working directory
  if savefig:
    fig.savefig("confusion_matrix.png")

Right now our predictions and truth labels are in the form of integers, however, they'll be much easier to understand if we get their actual names. We can do so using the class_names attribute on our test_data object.

In [36]:
# Get the class names
class_names = test_data.class_names
class_names
Out[36]:
['apple_pie',
 'baby_back_ribs',
 'baklava',
 'beef_carpaccio',
 'beef_tartare',
 'beet_salad',
 'beignets',
 'bibimbap',
 'bread_pudding',
 'breakfast_burrito',
 'bruschetta',
 'caesar_salad',
 'cannoli',
 'caprese_salad',
 'carrot_cake',
 'ceviche',
 'cheese_plate',
 'cheesecake',
 'chicken_curry',
 'chicken_quesadilla',
 'chicken_wings',
 'chocolate_cake',
 'chocolate_mousse',
 'churros',
 'clam_chowder',
 'club_sandwich',
 'crab_cakes',
 'creme_brulee',
 'croque_madame',
 'cup_cakes',
 'deviled_eggs',
 'donuts',
 'dumplings',
 'edamame',
 'eggs_benedict',
 'escargots',
 'falafel',
 'filet_mignon',
 'fish_and_chips',
 'foie_gras',
 'french_fries',
 'french_onion_soup',
 'french_toast',
 'fried_calamari',
 'fried_rice',
 'frozen_yogurt',
 'garlic_bread',
 'gnocchi',
 'greek_salad',
 'grilled_cheese_sandwich',
 'grilled_salmon',
 'guacamole',
 'gyoza',
 'hamburger',
 'hot_and_sour_soup',
 'hot_dog',
 'huevos_rancheros',
 'hummus',
 'ice_cream',
 'lasagna',
 'lobster_bisque',
 'lobster_roll_sandwich',
 'macaroni_and_cheese',
 'macarons',
 'miso_soup',
 'mussels',
 'nachos',
 'omelette',
 'onion_rings',
 'oysters',
 'pad_thai',
 'paella',
 'pancakes',
 'panna_cotta',
 'peking_duck',
 'pho',
 'pizza',
 'pork_chop',
 'poutine',
 'prime_rib',
 'pulled_pork_sandwich',
 'ramen',
 'ravioli',
 'red_velvet_cake',
 'risotto',
 'samosa',
 'sashimi',
 'scallops',
 'seaweed_salad',
 'shrimp_and_grits',
 'spaghetti_bolognese',
 'spaghetti_carbonara',
 'spring_rolls',
 'steak',
 'strawberry_shortcake',
 'sushi',
 'tacos',
 'takoyaki',
 'tiramisu',
 'tuna_tartare',
 'waffles']

101 class names and 25250 predictions and ground truth labels ready to go! Looks like our confusion matrix is going to be a big one!
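As a quick aside, integer predictions can be turned into readable names by indexing into the class names list. A tiny sketch with a hypothetical three-class subset (example_class_names and example_pred_classes are made up for illustration):

```python
import numpy as np

# Hypothetical subset of class names and predicted class indices
example_class_names = ["apple_pie", "baby_back_ribs", "baklava"]
example_pred_classes = np.array([2, 0, 0, 1])

# Index into the class names list with each integer prediction
example_pred_names = [example_class_names[i] for i in example_pred_classes]
print(example_pred_names)  # ['baklava', 'apple_pie', 'apple_pie', 'baby_back_ribs']
```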

In [37]:
# Plot a confusion matrix with all 25250 predictions, ground truth labels and 101 classes
make_confusion_matrix(y_true=y_labels,
                      y_pred=pred_classes,
                      classes=class_names,
                      figsize=(100, 100),
                      text_size=20,
                      norm=False,
                      savefig=True)

Woah! Now that's a big confusion matrix. It may look a little daunting at first but after zooming in a little, we can see how it gives us insight into which classes it's getting "confused" on.

The good news is, the majority of the predictions are right down the top left to bottom right diagonal, meaning they're correct.

It looks like the model gets most confused on classes which look visually similar, such as predicting filet_mignon for instances of pork_chop and chocolate_cake for instances of tiramisu.
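Rather than zooming around a giant plot, one option is to pull the largest off-diagonal cells out of the confusion matrix directly. The following is a sketch only: most_confused_pairs() is a hypothetical helper (it's not part of helper_functions.py) and the toy labels are made up, but the same idea should apply to y_labels and pred_classes.

```python
import numpy as np

def most_confused_pairs(y_true, y_pred, n_classes, top_k=3):
    """Return the top_k (true_class, predicted_class, count) off-diagonal pairs."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # rows = true labels, columns = predictions
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)       # ignore correct predictions
    flat_order = np.argsort(off_diag, axis=None)[::-1][:top_k]  # largest counts first
    rows, cols = np.unravel_index(flat_order, off_diag.shape)
    return [(int(r), int(c), int(off_diag[r, c])) for r, c in zip(rows, cols)]

# Hypothetical labels for a 3-class problem
toy_true = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2])
toy_pred = np.array([0, 1, 1, 1, 1, 1, 0, 2, 2, 1, 1])

print(most_confused_pairs(toy_true, toy_pred, n_classes=3, top_k=2))
# [(0, 1, 3), (2, 1, 2)] -> class 0 was predicted as class 1 three times
```

Pairs with equal counts may come back in either order, so treat the ranking among ties as arbitrary.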

Since we're working on a classification problem, we can further evaluate our model's predictions using Scikit-Learn's classification_report() function.

In [38]:
from sklearn.metrics import classification_report
print(classification_report(y_labels, pred_classes))
              precision    recall  f1-score   support

           0       0.29      0.20      0.24       250
           1       0.51      0.69      0.59       250
           2       0.56      0.65      0.60       250
           3       0.74      0.53      0.62       250
           4       0.73      0.43      0.54       250
           5       0.34      0.54      0.42       250
           6       0.67      0.79      0.72       250
           7       0.82      0.76      0.79       250
           8       0.40      0.37      0.39       250
           9       0.62      0.44      0.51       250
          10       0.62      0.42      0.50       250
          11       0.84      0.49      0.62       250
          12       0.52      0.74      0.61       250
          13       0.56      0.60      0.58       250
          14       0.56      0.59      0.57       250
          15       0.44      0.32      0.37       250
          16       0.45      0.75      0.57       250
          17       0.37      0.51      0.43       250
          18       0.43      0.60      0.50       250
          19       0.68      0.60      0.64       250
          20       0.68      0.75      0.71       250
          21       0.35      0.64      0.45       250
          22       0.30      0.37      0.33       250
          23       0.66      0.77      0.71       250
          24       0.83      0.72      0.77       250
          25       0.76      0.71      0.73       250
          26       0.51      0.42      0.46       250
          27       0.78      0.72      0.75       250
          28       0.70      0.69      0.69       250
          29       0.70      0.68      0.69       250
          30       0.92      0.63      0.75       250
          31       0.78      0.70      0.74       250
          32       0.75      0.83      0.79       250
          33       0.89      0.98      0.94       250
          34       0.68      0.78      0.72       250
          35       0.78      0.66      0.72       250
          36       0.53      0.56      0.55       250
          37       0.30      0.55      0.39       250
          38       0.78      0.63      0.69       250
          39       0.27      0.33      0.30       250
          40       0.72      0.81      0.76       250
          41       0.81      0.62      0.70       250
          42       0.50      0.58      0.54       250
          43       0.75      0.60      0.67       250
          44       0.74      0.45      0.56       250
          45       0.77      0.85      0.81       250
          46       0.81      0.46      0.58       250
          47       0.44      0.49      0.46       250
          48       0.45      0.81      0.58       250
          49       0.50      0.44      0.47       250
          50       0.54      0.39      0.46       250
          51       0.71      0.86      0.78       250
          52       0.51      0.77      0.61       250
          53       0.67      0.68      0.68       250
          54       0.88      0.75      0.81       250
          55       0.86      0.69      0.76       250
          56       0.56      0.24      0.34       250
          57       0.62      0.45      0.52       250
          58       0.68      0.58      0.62       250
          59       0.70      0.37      0.49       250
          60       0.83      0.59      0.69       250
          61       0.54      0.81      0.65       250
          62       0.72      0.49      0.58       250
          63       0.94      0.86      0.90       250
          64       0.78      0.85      0.81       250
          65       0.82      0.82      0.82       250
          66       0.69      0.32      0.44       250
          67       0.41      0.58      0.48       250
          68       0.90      0.78      0.83       250
          69       0.84      0.82      0.83       250
          70       0.62      0.83      0.71       250
          71       0.81      0.46      0.59       250
          72       0.64      0.65      0.65       250
          73       0.51      0.44      0.47       250
          74       0.72      0.61      0.66       250
          75       0.84      0.90      0.87       250
          76       0.78      0.78      0.78       250
          77       0.36      0.27      0.31       250
          78       0.79      0.74      0.76       250
          79       0.44      0.81      0.57       250
          80       0.57      0.60      0.59       250
          81       0.65      0.70      0.68       250
          82       0.38      0.31      0.34       250
          83       0.58      0.80      0.67       250
          84       0.61      0.38      0.47       250
          85       0.44      0.74      0.55       250
          86       0.71      0.86      0.78       250
          87       0.41      0.39      0.40       250
          88       0.83      0.80      0.81       250
          89       0.71      0.31      0.43       250
          90       0.92      0.69      0.79       250
          91       0.83      0.87      0.85       250
          92       0.68      0.65      0.67       250
          93       0.31      0.38      0.34       250
          94       0.61      0.54      0.57       250
          95       0.74      0.61      0.67       250
          96       0.56      0.29      0.38       250
          97       0.45      0.74      0.56       250
          98       0.47      0.33      0.39       250
          99       0.52      0.27      0.35       250
         100       0.59      0.70      0.64       250

    accuracy                           0.61     25250
   macro avg       0.63      0.61      0.61     25250
weighted avg       0.63      0.61      0.61     25250

The classification_report() function outputs the precision, recall and F1-score per class.

A reminder:

  • Precision - Proportion of true positives over the total number of positive predictions (true positives plus false positives). Higher precision leads to fewer false positives (model predicts 1 when it should've been 0).
  • Recall - Proportion of true positives over the total number of true positives and false negatives (model predicts 0 when it should've been 1). Higher recall leads to fewer false negatives.
  • F1 score - Combines precision and recall into one metric. 1 is best, 0 is worst.
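To see how the three metrics relate, here's a small worked sketch for a single class. The counts are made up (they're not from the report above) and precision_recall_f1() is a hypothetical helper, not a Scikit-Learn function.

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall and F1 from counts for one class."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0  # of everything predicted as this class, how much was right?
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0     # of everything that truly is this class, how much was found?
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)    # harmonic mean of precision and recall
    return precision, recall, f1

# Hypothetical counts for one class with 250 test images (tp + fn = 250)
example_precision, example_recall, example_f1 = precision_recall_f1(tp=150, fp=50, fn=100)
print(f"Precision: {example_precision:.2f}")  # 0.75
print(f"Recall: {example_recall:.2f}")        # 0.60
print(f"F1 score: {example_f1:.2f}")          # 0.67
```

Because F1 is a harmonic mean, it gets pulled towards the lower of the two values, so a class needs both good precision and good recall to score well.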

The above output is helpful but with so many classes, it's a bit hard to understand.

Let's see if we can make it easier with the help of a visualization.

First, we'll get the output of classification_report() as a dictionary by setting output_dict=True.

In [39]:
# Get a dictionary of the classification report
classification_report_dict = classification_report(y_labels, pred_classes, output_dict=True)
classification_report_dict
Out[39]:
{'0': {'f1-score': 0.24056603773584903,
  'precision': 0.29310344827586204,
  'recall': 0.204,
  'support': 250},
 '1': {'f1-score': 0.5864406779661017,
  'precision': 0.5088235294117647,
  'recall': 0.692,
  'support': 250},
 '10': {'f1-score': 0.5047619047619047,
  'precision': 0.6235294117647059,
  'recall': 0.424,
  'support': 250},
 '100': {'f1-score': 0.641025641025641,
  'precision': 0.5912162162162162,
  'recall': 0.7,
  'support': 250},
 '11': {'f1-score': 0.6161616161616161,
  'precision': 0.8356164383561644,
  'recall': 0.488,
  'support': 250},
 '12': {'f1-score': 0.6105610561056106,
  'precision': 0.5196629213483146,
  'recall': 0.74,
  'support': 250},
 '13': {'f1-score': 0.5775193798449612,
  'precision': 0.5601503759398496,
  'recall': 0.596,
  'support': 250},
 '14': {'f1-score': 0.574757281553398,
  'precision': 0.5584905660377358,
  'recall': 0.592,
  'support': 250},
 '15': {'f1-score': 0.36744186046511623,
  'precision': 0.4388888888888889,
  'recall': 0.316,
  'support': 250},
 '16': {'f1-score': 0.5654135338345864,
  'precision': 0.4530120481927711,
  'recall': 0.752,
  'support': 250},
 '17': {'f1-score': 0.42546063651591287,
  'precision': 0.3659942363112392,
  'recall': 0.508,
  'support': 250},
 '18': {'f1-score': 0.5008403361344538,
  'precision': 0.4318840579710145,
  'recall': 0.596,
  'support': 250},
 '19': {'f1-score': 0.6411889596602972,
  'precision': 0.6832579185520362,
  'recall': 0.604,
  'support': 250},
 '2': {'f1-score': 0.6022304832713754,
  'precision': 0.5625,
  'recall': 0.648,
  'support': 250},
 '20': {'f1-score': 0.7123809523809523,
  'precision': 0.68,
  'recall': 0.748,
  'support': 250},
 '21': {'f1-score': 0.45261669024045265,
  'precision': 0.350109409190372,
  'recall': 0.64,
  'support': 250},
 '22': {'f1-score': 0.3291592128801431,
  'precision': 0.2977346278317152,
  'recall': 0.368,
  'support': 250},
 '23': {'f1-score': 0.7134935304990757,
  'precision': 0.6632302405498282,
  'recall': 0.772,
  'support': 250},
 '24': {'f1-score': 0.7708779443254817,
  'precision': 0.8294930875576036,
  'recall': 0.72,
  'support': 250},
 '25': {'f1-score': 0.734020618556701,
  'precision': 0.7574468085106383,
  'recall': 0.712,
  'support': 250},
 '26': {'f1-score': 0.4625550660792952,
  'precision': 0.5147058823529411,
  'recall': 0.42,
  'support': 250},
 '27': {'f1-score': 0.7494824016563146,
  'precision': 0.776824034334764,
  'recall': 0.724,
  'support': 250},
 '28': {'f1-score': 0.6935483870967742,
  'precision': 0.6991869918699187,
  'recall': 0.688,
  'support': 250},
 '29': {'f1-score': 0.6910569105691057,
  'precision': 0.7024793388429752,
  'recall': 0.68,
  'support': 250},
 '3': {'f1-score': 0.616822429906542,
  'precision': 0.7415730337078652,
  'recall': 0.528,
  'support': 250},
 '30': {'f1-score': 0.7476190476190476,
  'precision': 0.9235294117647059,
  'recall': 0.628,
  'support': 250},
 '31': {'f1-score': 0.7357293868921776,
  'precision': 0.7802690582959642,
  'recall': 0.696,
  'support': 250},
 '32': {'f1-score': 0.7855787476280836,
  'precision': 0.7472924187725631,
  'recall': 0.828,
  'support': 250},
 '33': {'f1-score': 0.9371428571428572,
  'precision': 0.8945454545454545,
  'recall': 0.984,
  'support': 250},
 '34': {'f1-score': 0.7238805970149255,
  'precision': 0.6783216783216783,
  'recall': 0.776,
  'support': 250},
 '35': {'f1-score': 0.715835140997831,
  'precision': 0.7819905213270142,
  'recall': 0.66,
  'support': 250},
 '36': {'f1-score': 0.5475728155339805,
  'precision': 0.5320754716981132,
  'recall': 0.564,
  'support': 250},
 '37': {'f1-score': 0.3870056497175141,
  'precision': 0.29912663755458513,
  'recall': 0.548,
  'support': 250},
 '38': {'f1-score': 0.6946902654867257,
  'precision': 0.7772277227722773,
  'recall': 0.628,
  'support': 250},
 '39': {'f1-score': 0.29749103942652333,
  'precision': 0.2694805194805195,
  'recall': 0.332,
  'support': 250},
 '4': {'f1-score': 0.544080604534005,
  'precision': 0.7346938775510204,
  'recall': 0.432,
  'support': 250},
 '40': {'f1-score': 0.7622641509433963,
  'precision': 0.7214285714285714,
  'recall': 0.808,
  'support': 250},
 '41': {'f1-score': 0.7029478458049886,
  'precision': 0.8115183246073299,
  'recall': 0.62,
  'support': 250},
 '42': {'f1-score': 0.537037037037037,
  'precision': 0.5,
  'recall': 0.58,
  'support': 250},
 '43': {'f1-score': 0.6651884700665188,
  'precision': 0.746268656716418,
  'recall': 0.6,
  'support': 250},
 '44': {'f1-score': 0.5586034912718205,
  'precision': 0.7417218543046358,
  'recall': 0.448,
  'support': 250},
 '45': {'f1-score': 0.8114285714285714,
  'precision': 0.7745454545454545,
  'recall': 0.852,
  'support': 250},
 '46': {'f1-score': 0.5831202046035805,
  'precision': 0.8085106382978723,
  'recall': 0.456,
  'support': 250},
 '47': {'f1-score': 0.4641509433962264,
  'precision': 0.4392857142857143,
  'recall': 0.492,
  'support': 250},
 '48': {'f1-score': 0.577524893314367,
  'precision': 0.4481236203090508,
  'recall': 0.812,
  'support': 250},
 '49': {'f1-score': 0.47234042553191485,
  'precision': 0.5045454545454545,
  'recall': 0.444,
  'support': 250},
 '5': {'f1-score': 0.41860465116279066,
  'precision': 0.34177215189873417,
  'recall': 0.54,
  'support': 250},
 '50': {'f1-score': 0.45581395348837206,
  'precision': 0.5444444444444444,
  'recall': 0.392,
  'support': 250},
 '51': {'f1-score': 0.7783783783783783,
  'precision': 0.7081967213114754,
  'recall': 0.864,
  'support': 250},
 '52': {'f1-score': 0.6124401913875598,
  'precision': 0.5092838196286472,
  'recall': 0.768,
  'support': 250},
 '53': {'f1-score': 0.6759443339960238,
  'precision': 0.6719367588932806,
  'recall': 0.68,
  'support': 250},
 '54': {'f1-score': 0.8103448275862069,
  'precision': 0.8785046728971962,
  'recall': 0.752,
  'support': 250},
 '55': {'f1-score': 0.7644444444444444,
  'precision': 0.86,
  'recall': 0.688,
  'support': 250},
 '56': {'f1-score': 0.3398328690807799,
  'precision': 0.5596330275229358,
  'recall': 0.244,
  'support': 250},
 '57': {'f1-score': 0.5209302325581396,
  'precision': 0.6222222222222222,
  'recall': 0.448,
  'support': 250},
 '58': {'f1-score': 0.6233766233766233,
  'precision': 0.6792452830188679,
  'recall': 0.576,
  'support': 250},
 '59': {'f1-score': 0.486910994764398,
  'precision': 0.7045454545454546,
  'recall': 0.372,
  'support': 250},
 '6': {'f1-score': 0.7229357798165138,
  'precision': 0.6677966101694915,
  'recall': 0.788,
  'support': 250},
 '60': {'f1-score': 0.6885245901639344,
  'precision': 0.8305084745762712,
  'recall': 0.588,
  'support': 250},
 '61': {'f1-score': 0.6495176848874598,
  'precision': 0.543010752688172,
  'recall': 0.808,
  'support': 250},
 '62': {'f1-score': 0.5823389021479712,
  'precision': 0.7218934911242604,
  'recall': 0.488,
  'support': 250},
 '63': {'f1-score': 0.895397489539749,
  'precision': 0.9385964912280702,
  'recall': 0.856,
  'support': 250},
 '64': {'f1-score': 0.8129770992366412,
  'precision': 0.7773722627737226,
  'recall': 0.852,
  'support': 250},
 '65': {'f1-score': 0.82, 'precision': 0.82, 'recall': 0.82, 'support': 250},
 '66': {'f1-score': 0.44141689373297005,
  'precision': 0.6923076923076923,
  'recall': 0.324,
  'support': 250},
 '67': {'f1-score': 0.47840531561461797,
  'precision': 0.4090909090909091,
  'recall': 0.576,
  'support': 250},
 '68': {'f1-score': 0.832618025751073,
  'precision': 0.8981481481481481,
  'recall': 0.776,
  'support': 250},
 '69': {'f1-score': 0.8340080971659919,
  'precision': 0.8442622950819673,
  'recall': 0.824,
  'support': 250},
 '7': {'f1-score': 0.7908902691511386,
  'precision': 0.8197424892703863,
  'recall': 0.764,
  'support': 250},
 '70': {'f1-score': 0.7101200686106347,
  'precision': 0.6216216216216216,
  'recall': 0.828,
  'support': 250},
 '71': {'f1-score': 0.5903307888040712,
  'precision': 0.8111888111888111,
  'recall': 0.464,
  'support': 250},
 '72': {'f1-score': 0.6468253968253969,
  'precision': 0.6417322834645669,
  'recall': 0.652,
  'support': 250},
 '73': {'f1-score': 0.4743589743589744,
  'precision': 0.5091743119266054,
  'recall': 0.444,
  'support': 250},
 '74': {'f1-score': 0.658008658008658,
  'precision': 0.7169811320754716,
  'recall': 0.608,
  'support': 250},
 '75': {'f1-score': 0.8665377176015473,
  'precision': 0.8389513108614233,
  'recall': 0.896,
  'support': 250},
 '76': {'f1-score': 0.7808764940239045,
  'precision': 0.7777777777777778,
  'recall': 0.784,
  'support': 250},
 '77': {'f1-score': 0.30875576036866365,
  'precision': 0.3641304347826087,
  'recall': 0.268,
  'support': 250},
 '78': {'f1-score': 0.7603305785123966,
  'precision': 0.7863247863247863,
  'recall': 0.736,
  'support': 250},
 '79': {'f1-score': 0.571830985915493,
  'precision': 0.44130434782608696,
  'recall': 0.812,
  'support': 250},
 '8': {'f1-score': 0.3866943866943867,
  'precision': 0.4025974025974026,
  'recall': 0.372,
  'support': 250},
 '80': {'f1-score': 0.5870841487279843,
  'precision': 0.5747126436781609,
  'recall': 0.6,
  'support': 250},
 '81': {'f1-score': 0.6756756756756757,
  'precision': 0.6529850746268657,
  'recall': 0.7,
  'support': 250},
 '82': {'f1-score': 0.34285714285714286,
  'precision': 0.3804878048780488,
  'recall': 0.312,
  'support': 250},
 '83': {'f1-score': 0.6711409395973154,
  'precision': 0.5780346820809249,
  'recall': 0.8,
  'support': 250},
 '84': {'f1-score': 0.4653465346534653,
  'precision': 0.6103896103896104,
  'recall': 0.376,
  'support': 250},
 '85': {'f1-score': 0.5525525525525525,
  'precision': 0.4423076923076923,
  'recall': 0.736,
  'support': 250},
 '86': {'f1-score': 0.7783783783783783,
  'precision': 0.7081967213114754,
  'recall': 0.864,
  'support': 250},
 '87': {'f1-score': 0.3975409836065574,
  'precision': 0.40756302521008403,
  'recall': 0.388,
  'support': 250},
 '88': {'f1-score': 0.8130081300813008,
  'precision': 0.8264462809917356,
  'recall': 0.8,
  'support': 250},
 '89': {'f1-score': 0.4301675977653631,
  'precision': 0.7129629629629629,
  'recall': 0.308,
  'support': 250},
 '9': {'f1-score': 0.5117370892018779,
  'precision': 0.6193181818181818,
  'recall': 0.436,
  'support': 250},
 '90': {'f1-score': 0.7881548974943051,
  'precision': 0.9153439153439153,
  'recall': 0.692,
  'support': 250},
 '91': {'f1-score': 0.84765625,
  'precision': 0.8282442748091603,
  'recall': 0.868,
  'support': 250},
 '92': {'f1-score': 0.6652977412731006,
  'precision': 0.6835443037974683,
  'recall': 0.648,
  'support': 250},
 '93': {'f1-score': 0.34234234234234234,
  'precision': 0.3114754098360656,
  'recall': 0.38,
  'support': 250},
 '94': {'f1-score': 0.5714285714285714,
  'precision': 0.6118721461187214,
  'recall': 0.536,
  'support': 250},
 '95': {'f1-score': 0.6710526315789473,
  'precision': 0.7427184466019418,
  'recall': 0.612,
  'support': 250},
 '96': {'f1-score': 0.3809523809523809,
  'precision': 0.5625,
  'recall': 0.288,
  'support': 250},
 '97': {'f1-score': 0.5644916540212443,
  'precision': 0.4547677261613692,
  'recall': 0.744,
  'support': 250},
 '98': {'f1-score': 0.3858823529411765,
  'precision': 0.4685714285714286,
  'recall': 0.328,
  'support': 250},
 '99': {'f1-score': 0.35356200527704484,
  'precision': 0.5193798449612403,
  'recall': 0.268,
  'support': 250},
 'accuracy': 0.6077623762376237,
 'macro avg': {'f1-score': 0.6061252197245781,
  'precision': 0.6328666845830312,
  'recall': 0.6077623762376237,
  'support': 25250},
 'weighted avg': {'f1-score': 0.606125219724578,
  'precision': 0.6328666845830311,
  'recall': 0.6077623762376237,
  'support': 25250}}

Alright, there's still a fair few values here, how about we narrow down?

Since the f1-score combines precision and recall in one metric, let's focus on that.

To extract it, we'll create an empty dictionary called class_f1_scores and then loop through each item in classification_report_dict, appending the class name and f1-score as the key, value pairs in class_f1_scores.

In [40]:
# Create empty dictionary
class_f1_scores = {}
# Loop through classification report items
for k, v in classification_report_dict.items():
  if k == "accuracy": # stop once we get to accuracy key
    break
  else:
    # Append class names and f1-scores to new dictionary
    class_f1_scores[class_names[int(k)]] = v["f1-score"]
class_f1_scores
Out[40]:
{'apple_pie': 0.24056603773584903,
 'baby_back_ribs': 0.5864406779661017,
 'baklava': 0.6022304832713754,
 'beef_carpaccio': 0.616822429906542,
 'beef_tartare': 0.544080604534005,
 'beet_salad': 0.41860465116279066,
 'beignets': 0.7229357798165138,
 'bibimbap': 0.7908902691511386,
 'bread_pudding': 0.3866943866943867,
 'breakfast_burrito': 0.5117370892018779,
 'bruschetta': 0.5047619047619047,
 'caesar_salad': 0.6161616161616161,
 'cannoli': 0.6105610561056106,
 'caprese_salad': 0.5775193798449612,
 'carrot_cake': 0.574757281553398,
 'ceviche': 0.36744186046511623,
 'cheese_plate': 0.5654135338345864,
 'cheesecake': 0.42546063651591287,
 'chicken_curry': 0.5008403361344538,
 'chicken_quesadilla': 0.6411889596602972,
 'chicken_wings': 0.7123809523809523,
 'chocolate_cake': 0.45261669024045265,
 'chocolate_mousse': 0.3291592128801431,
 'churros': 0.7134935304990757,
 'clam_chowder': 0.7708779443254817,
 'club_sandwich': 0.734020618556701,
 'crab_cakes': 0.4625550660792952,
 'creme_brulee': 0.7494824016563146,
 'croque_madame': 0.6935483870967742,
 'cup_cakes': 0.6910569105691057,
 'deviled_eggs': 0.7476190476190476,
 'donuts': 0.7357293868921776,
 'dumplings': 0.7855787476280836,
 'edamame': 0.9371428571428572,
 'eggs_benedict': 0.7238805970149255,
 'escargots': 0.715835140997831,
 'falafel': 0.5475728155339805,
 'filet_mignon': 0.3870056497175141,
 'fish_and_chips': 0.6946902654867257,
 'foie_gras': 0.29749103942652333,
 'french_fries': 0.7622641509433963,
 'french_onion_soup': 0.7029478458049886,
 'french_toast': 0.537037037037037,
 'fried_calamari': 0.6651884700665188,
 'fried_rice': 0.5586034912718205,
 'frozen_yogurt': 0.8114285714285714,
 'garlic_bread': 0.5831202046035805,
 'gnocchi': 0.4641509433962264,
 'greek_salad': 0.577524893314367,
 'grilled_cheese_sandwich': 0.47234042553191485,
 'grilled_salmon': 0.45581395348837206,
 'guacamole': 0.7783783783783783,
 'gyoza': 0.6124401913875598,
 'hamburger': 0.6759443339960238,
 'hot_and_sour_soup': 0.8103448275862069,
 'hot_dog': 0.7644444444444444,
 'huevos_rancheros': 0.3398328690807799,
 'hummus': 0.5209302325581396,
 'ice_cream': 0.6233766233766233,
 'lasagna': 0.486910994764398,
 'lobster_bisque': 0.6885245901639344,
 'lobster_roll_sandwich': 0.6495176848874598,
 'macaroni_and_cheese': 0.5823389021479712,
 'macarons': 0.895397489539749,
 'miso_soup': 0.8129770992366412,
 'mussels': 0.82,
 'nachos': 0.44141689373297005,
 'omelette': 0.47840531561461797,
 'onion_rings': 0.832618025751073,
 'oysters': 0.8340080971659919,
 'pad_thai': 0.7101200686106347,
 'paella': 0.5903307888040712,
 'pancakes': 0.6468253968253969,
 'panna_cotta': 0.4743589743589744,
 'peking_duck': 0.658008658008658,
 'pho': 0.8665377176015473,
 'pizza': 0.7808764940239045,
 'pork_chop': 0.30875576036866365,
 'poutine': 0.7603305785123966,
 'prime_rib': 0.571830985915493,
 'pulled_pork_sandwich': 0.5870841487279843,
 'ramen': 0.6756756756756757,
 'ravioli': 0.34285714285714286,
 'red_velvet_cake': 0.6711409395973154,
 'risotto': 0.4653465346534653,
 'samosa': 0.5525525525525525,
 'sashimi': 0.7783783783783783,
 'scallops': 0.3975409836065574,
 'seaweed_salad': 0.8130081300813008,
 'shrimp_and_grits': 0.4301675977653631,
 'spaghetti_bolognese': 0.7881548974943051,
 'spaghetti_carbonara': 0.84765625,
 'spring_rolls': 0.6652977412731006,
 'steak': 0.34234234234234234,
 'strawberry_shortcake': 0.5714285714285714,
 'sushi': 0.6710526315789473,
 'tacos': 0.3809523809523809,
 'takoyaki': 0.5644916540212443,
 'tiramisu': 0.3858823529411765,
 'tuna_tartare': 0.35356200527704484,
 'waffles': 0.641025641025641}

Looking good!

It seems like our dictionary is ordered by the class names. However, I think if we're trying to visualize different scores, it might look nicer if they were in some kind of order.

How about we turn our class_f1_scores dictionary into a pandas DataFrame and sort it in descending fashion (highest f1-score first)?

In [41]:
# Turn f1-scores into dataframe for visualization
import pandas as pd
f1_scores = pd.DataFrame({"class_name": list(class_f1_scores.keys()),
                          "f1-score": list(class_f1_scores.values())}).sort_values("f1-score", ascending=False)
f1_scores
Out[41]:
    class_name           f1-score
33  edamame              0.937143
63  macarons             0.895397
75  pho                  0.866538
91  spaghetti_carbonara  0.847656
69  oysters              0.834008
..  ...                  ...
56  huevos_rancheros     0.339833
22  chocolate_mousse     0.329159
77  pork_chop            0.308756
39  foie_gras            0.297491
0   apple_pie            0.240566

101 rows × 2 columns

Now we're talking! Let's finish it off with a nice horizontal bar chart.

In [42]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(12, 25))
scores = ax.barh(range(len(f1_scores)), f1_scores["f1-score"].values)
ax.set_yticks(range(len(f1_scores)))
ax.set_yticklabels(list(f1_scores["class_name"]))
ax.set_xlabel("f1-score")
ax.set_title("F1-Scores for 101 Different Classes")
ax.invert_yaxis(); # reverse the order

def autolabel(rects): # Modified version of: https://matplotlib.org/examples/api/barchart_demo.html
  """
  Attach a text label next to the end of each horizontal bar displaying its width (its value).
  """
  for rect in rects:
    width = rect.get_width()
    ax.text(1.03*width, rect.get_y() + rect.get_height()/1.5,
            f"{width:.2f}",
            ha='center', va='bottom')

autolabel(scores)

Now that's a good looking graph! I mean, the text positioning could be improved a little but it'll do for now.

Can you see how visualizing our model's predictions gives us a completely new insight into its performance?

A few moments ago we only had an accuracy score but now we've got an indication of how well our model is performing on a class-by-class basis.

It seems like our model performs fairly poorly on classes like apple_pie and ravioli while for classes like edamame and pho the performance is quite outstanding.

Findings like these give us clues into where we could go next with our experiments. Perhaps we may have to collect more data on poor performing classes or perhaps the worst performing classes are just hard to make predictions on.
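
To pick out the classes that most need attention without reading the whole chart, one option is to pull the lowest-scoring entries directly from the f1-score dictionary. The sketch below uses a hypothetical helper, worst_classes(), and a handful of values copied from the class_f1_scores output above rather than the full dictionary.

```python
import heapq

# A handful of per-class F1 scores copied from the class_f1_scores output above.
sample_f1_scores = {
    "apple_pie": 0.24056603773584903,
    "edamame": 0.9371428571428572,
    "foie_gras": 0.29749103942652333,
    "pho": 0.8665377176015473,
    "pork_chop": 0.30875576036866365,
    "ravioli": 0.34285714285714286,
}

def worst_classes(f1_scores, k=3):
    """Return the k (class_name, f1) pairs with the lowest F1, worst first."""
    return heapq.nsmallest(k, f1_scores.items(), key=lambda item: item[1])

for name, score in worst_classes(sample_f1_scores, k=3):
    print(f"{name}: {score:.2f}")
# apple_pie: 0.24
# foie_gras: 0.30
# pork_chop: 0.31
```

In the notebook itself you could pass the full class_f1_scores dictionary in place of sample_f1_scores.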

🛠 Exercise: Visualize some of the worst performing classes. Do you notice any trends among them?

Visualizing predictions on test images

Time for the real test. Visualizing predictions on actual images. You can look at all the metrics you want but until you've visualized some predictions, you won't really know how your model is performing.

As it stands, our model can't just predict on any image of our choice. The image first has to be loaded into a tensor.

So to begin predicting on any given image, we'll create a function to load an image into a tensor.

Specifically, it'll:

  • Read in a target image filepath using tf.io.read_file().
  • Turn the image into a Tensor using tf.io.decode_image().
  • Resize the image to be the same size as the images our model has been trained on (224 x 224) using tf.image.resize().
  • Scale the image to get all the pixel values between 0 & 1 if necessary.
In [43]:
def load_and_prep_image(filename, img_shape=224, scale=True):
  """
  Reads in an image from filename, turns it into a tensor and reshapes into
  (224, 224, 3).

  Parameters
  ----------
  filename (str): string filename of target image
  img_shape (int): size to resize target image to, default 224
  scale (bool): whether to scale pixel values to range(0, 1), default True
  """
  # Read in the image
  img = tf.io.read_file(filename)
  # Decode it into a tensor
  img = tf.io.decode_image(img)
  # Resize the image
  img = tf.image.resize(img, [img_shape, img_shape])
  if scale:
    # Rescale the image (get all values between 0 and 1)
    return img/255.
  else:
    return img

Image loading and preprocessing function ready.

Now let's write some code to:

  1. Load a few random images from the test dataset.
  2. Make predictions on them.
  3. Plot the original image(s) along with the model's predicted label, prediction probability and ground truth label.
In [44]:
# Make preds on a series of random images
import os
import random

plt.figure(figsize=(17, 10))
for i in range(3):
  # Choose a random image from a random class 
  class_name = random.choice(class_names)
  class_dir = os.path.join(test_dir, class_name) # works whether or not test_dir ends with "/"
  filename = random.choice(os.listdir(class_dir))
  filepath = os.path.join(class_dir, filename)

  # Load the image and make predictions
  img = load_and_prep_image(filepath, scale=False) # don't scale images for EfficientNet predictions
  pred_prob = model.predict(tf.expand_dims(img, axis=0)) # model accepts tensors of shape [None, 224, 224, 3]
  pred_class = class_names[pred_prob.argmax()] # find the predicted class 

  # Plot the image(s)
  plt.subplot(1, 3, i+1)
  plt.imshow(img/255.)
  if class_name == pred_class: # Change the color of text based on whether prediction is right or wrong
    title_color = "g"
  else:
    title_color = "r"
  plt.title(f"actual: {class_name}, pred: {pred_class}, prob: {pred_prob.max():.2f}", c=title_color)
  plt.axis(False);

After going through enough random samples, it starts to become clear that the model tends to make far worse predictions on classes which are visually similar, such as baby_back_ribs getting mistaken for steak and vice versa.

Finding the most wrong predictions

It's a good idea to go through 100+ random instances of your model's predictions to get a good feel for how it's doing.

After a while you might notice the model predicting on some images with a very high prediction probability, meaning it's very confident with its prediction but still getting the label wrong.

These most wrong predictions can help to give further insight into your model's performance.

So how about we write some code to collect all of the predictions where the model has output a high prediction probability for an image (e.g. 0.95+) but gotten the prediction wrong?

We'll go through the following steps:

  1. Get all of the image file paths in the test dataset using the list_files() method.
  2. Create a pandas DataFrame of the image filepaths, ground truth labels, prediction classes, max prediction probabilities, ground truth class names and predicted class names.
    • Note: We don't necessarily have to create a DataFrame like this but it'll help us visualize things as we go.
  3. Use our DataFrame to find all the wrong predictions (where the ground truth doesn't match the prediction).
  4. Sort the DataFrame based on wrong predictions and highest max prediction probabilities.
  5. Visualize the images which have the highest prediction probabilities but the wrong prediction.
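
Before building the full DataFrame, the core idea (wrong and confident) can be sketched with plain Python lists. The helper name confident_mistakes() and the toy values are hypothetical; only the 0.95 threshold comes from the text above.

```python
def confident_mistakes(y_true, y_pred, pred_conf, threshold=0.95):
    """Return indices of wrong predictions with confidence >= threshold,
    ordered from most to least confident."""
    wrong = [
        i for i, (t, p, c) in enumerate(zip(y_true, y_pred, pred_conf))
        if t != p and c >= threshold
    ]
    return sorted(wrong, key=lambda i: pred_conf[i], reverse=True)

# Toy values (hypothetical, not taken from the model's real outputs).
y_true = [0, 0, 1, 2, 2]
y_pred = [0, 1, 1, 0, 1]
pred_conf = [0.99, 0.97, 0.60, 0.999, 0.50]

print(confident_mistakes(y_true, y_pred, pred_conf))
# [3, 1]
```

Sample 4 is also wrong but falls below the threshold, so it is left out. The DataFrame approach used in the cells below does the same filtering and sorting while keeping the file paths and class names alongside each prediction.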
In [45]:
# 1. Get the filenames of all of our test data
filepaths = []
for filepath in test_data.list_files("101_food_classes_10_percent/test/*/*.jpg", 
                                     shuffle=False):
  filepaths.append(filepath.numpy())
filepaths[:10]
Out[45]:
[b'101_food_classes_10_percent/test/apple_pie/1011328.jpg',
 b'101_food_classes_10_percent/test/apple_pie/101251.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1034399.jpg',
 b'101_food_classes_10_percent/test/apple_pie/103801.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1038694.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1047447.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1068632.jpg',
 b'101_food_classes_10_percent/test/apple_pie/110043.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1106961.jpg',
 b'101_food_classes_10_percent/test/apple_pie/1113017.jpg']

Now we've got all of the test image filepaths, let's combine them into a DataFrame along with:

  • Their ground truth labels (y_labels).
  • The class the model predicted (pred_classes).
  • The maximum prediction probability value (pred_probs.max(axis=1)).
  • The ground truth class names.
  • The predicted class names.
In [46]:
# 2. Create a dataframe out of current prediction data for analysis
import pandas as pd
pred_df = pd.DataFrame({"img_path": filepaths,
                        "y_true": y_labels,
                        "y_pred": pred_classes,
                        "pred_conf": pred_probs.max(axis=1), # get the maximum prediction probability value
                        "y_true_classname": [class_names[i] for i in y_labels],
                        "y_pred_classname": [class_names[i] for i in pred_classes]}) 
pred_df.head()
Out[46]:
   img_path                                           y_true  y_pred  pred_conf  y_true_classname  y_pred_classname
0  b'101_food_classes_10_percent/test/apple_pie/1...  0       52      0.847419   apple_pie         gyoza
1  b'101_food_classes_10_percent/test/apple_pie/1...  0       0       0.964017   apple_pie         apple_pie
2  b'101_food_classes_10_percent/test/apple_pie/1...  0       0       0.959259   apple_pie         apple_pie
3  b'101_food_classes_10_percent/test/apple_pie/1...  0       80      0.658607   apple_pie         pulled_pork_sandwich
4  b'101_food_classes_10_percent/test/apple_pie/1...  0       79      0.367901   apple_pie         prime_rib

Nice! How about we make a simple column telling us whether or not the prediction is right or wrong?

In [47]:
# 3. Is the prediction correct?
pred_df["pred_correct"] = pred_df["y_true"] == pred_df["y_pred"]
pred_df.head()
Out[47]:
   img_path                                           y_true  y_pred  pred_conf  y_true_classname  y_pred_classname      pred_correct
0  b'101_food_classes_10_percent/test/apple_pie/1...  0       52      0.847419   apple_pie         gyoza                 False
1  b'101_food_classes_10_percent/test/apple_pie/1...  0       0       0.964017   apple_pie         apple_pie             True
2  b'101_food_classes_10_percent/test/apple_pie/1...  0       0       0.959259   apple_pie         apple_pie             True
3  b'101_food_classes_10_percent/test/apple_pie/1...  0       80      0.658607   apple_pie         pulled_pork_sandwich  False
4  b'101_food_classes_10_percent/test/apple_pie/1...  0       79      0.367901   apple_pie         prime_rib             False

And now that we know which predictions were right or wrong, along with their prediction probabilities, how about we get the 100 "most wrong" predictions by sorting the wrong predictions by descending prediction probability?

In [48]:
# 4. Get the top 100 wrong examples
top_100_wrong = pred_df[pred_df["pred_correct"] == False].sort_values("pred_conf", ascending=False)[:100]
top_100_wrong.head(20)
Out[48]:
       img_path                                           y_true  y_pred  pred_conf  y_true_classname       y_pred_classname       pred_correct
21810  b'101_food_classes_10_percent/test/scallops/17...  87      29      0.999997   scallops               cup_cakes              False
231    b'101_food_classes_10_percent/test/apple_pie/8...  0       100     0.999995   apple_pie              waffles                False
15359  b'101_food_classes_10_percent/test/lobster_rol...  61      53      0.999988   lobster_roll_sandwich  hamburger              False
23539  b'101_food_classes_10_percent/test/strawberry_...  94      83      0.999987   strawberry_shortcake   red_velvet_cake        False
21400  b'101_food_classes_10_percent/test/samosa/3140...  85      92      0.999981   samosa                 spring_rolls           False
24540  b'101_food_classes_10_percent/test/tiramisu/16...  98      83      0.999947   tiramisu               red_velvet_cake        False
2511   b'101_food_classes_10_percent/test/bruschetta/...  10      61      0.999945   bruschetta             lobster_roll_sandwich  False
5574   b'101_food_classes_10_percent/test/chocolate_m...  22      21      0.999939   chocolate_mousse       chocolate_cake         False
17855  b'101_food_classes_10_percent/test/paella/2314...  71      65      0.999931   paella                 mussels                False
23797  b'101_food_classes_10_percent/test/sushi/16593...  95      86      0.999904   sushi                  sashimi                False
18001  b'101_food_classes_10_percent/test/pancakes/10...  72      67      0.999904   pancakes               omelette               False
11642  b'101_food_classes_10_percent/test/garlic_brea...  46      10      0.999877   garlic_bread           bruschetta             False
10847  b'101_food_classes_10_percent/test/fried_calam...  43      68      0.999872   fried_calamari         onion_rings            False
23631  b'101_food_classes_10_percent/test/strawberry_...  94      83      0.999858   strawberry_shortcake   red_velvet_cake        False
1155   b'101_food_classes_10_percent/test/beef_tartar...  4       5       0.999858   beef_tartare           beet_salad             False
10854  b'101_food_classes_10_percent/test/fried_calam...  43      68      0.999854   fried_calamari         onion_rings            False
23904  b'101_food_classes_10_percent/test/sushi/33652...  95      86      0.999823   sushi                  sashimi                False
7316   b'101_food_classes_10_percent/test/cup_cakes/1...  29      83      0.999816   cup_cakes              red_velvet_cake        False
13144  b'101_food_classes_10_percent/test/gyoza/31214...  52      92      0.999799   gyoza                  spring_rolls           False
10880  b'101_food_classes_10_percent/test/fried_calam...  43      68      0.999778   fried_calamari         onion_rings            False

Very interesting... just by comparing the ground truth classname (y_true_classname) and the prediction classname column (y_pred_classname), do you notice any trends?
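
One way to look for trends is to count how often each (actual, predicted) pair appears. The sketch below hard-codes the 20 pairs shown in the output above. In the notebook, the same list could be built from the y_true_classname and y_pred_classname columns of top_100_wrong.

```python
from collections import Counter

# (actual, predicted) class names copied from the 20 rows shown above.
wrong_pairs = [
    ("scallops", "cup_cakes"),
    ("apple_pie", "waffles"),
    ("lobster_roll_sandwich", "hamburger"),
    ("strawberry_shortcake", "red_velvet_cake"),
    ("samosa", "spring_rolls"),
    ("tiramisu", "red_velvet_cake"),
    ("bruschetta", "lobster_roll_sandwich"),
    ("chocolate_mousse", "chocolate_cake"),
    ("paella", "mussels"),
    ("sushi", "sashimi"),
    ("pancakes", "omelette"),
    ("garlic_bread", "bruschetta"),
    ("fried_calamari", "onion_rings"),
    ("strawberry_shortcake", "red_velvet_cake"),
    ("beef_tartare", "beet_salad"),
    ("fried_calamari", "onion_rings"),
    ("sushi", "sashimi"),
    ("cup_cakes", "red_velvet_cake"),
    ("gyoza", "spring_rolls"),
    ("fried_calamari", "onion_rings"),
]

# How often does each exact confusion occur?
pair_counts = Counter(wrong_pairs)
# Which class does the model most often predict when it is confidently wrong?
pred_counts = Counter(pred for _, pred in wrong_pairs)

print(pair_counts.most_common(1))
# [(('fried_calamari', 'onion_rings'), 3)]
print(pred_counts.most_common(1))
# [('red_velvet_cake', 4)]
```

Even in this small sample, the repeated pairs are foods that look alike (fried calamari and onion rings, sushi and sashimi), and red_velvet_cake shows up as the wrong answer for several different desserts.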

It might be easier if we visualize them.

In [49]:
# 5. Visualize some of the most wrong examples
images_to_view = 9
start_index = 10 # change the start index to view more
plt.figure(figsize=(15, 10))
for i, row in enumerate(top_100_wrong[start_index:start_index+images_to_view].itertuples()): 
  plt.subplot(3, 3, i+1)
  img = load_and_prep_image(row[1], scale=True)
  _, _, _, _, pred_prob, y_true, y_pred, _ = row # only interested in a few parameters of each row
  plt.imshow(img)
  plt.title(f"actual: {y_true}, pred: {y_pred} \nprob: {pred_prob:.2f}")
  plt.axis(False)

Going through the model's most wrong predictions can usually help figure out a couple of things:

  • Some of the labels might be wrong - If our model ends up being good enough, it may actually learn to predict very well on certain classes. This means some images for which the model predicts the right label may show up as wrong because the ground truth label is wrong. If this is the case, we can often use our model to help us improve the labels in our dataset(s) and, in turn, potentially make future models better. This process of using the model to help improve labels is often referred to as active learning.
  • Could more samples be collected? - If there's a recurring pattern for a certain class being poorly predicted on, perhaps it's a good idea to collect more samples of that particular class in different scenarios to improve further models.

Test out the big dog model on test images as well as custom images of food

So far we've visualized some of our model's predictions from the test dataset but it's time for the real test: using our model to make predictions on our own custom images of food.

For this you might want to upload your own images to Google Colab or put them in a folder you can load into the notebook.

In my case, I've prepared my own small dataset of six or so images of various foods.

Let's download them and unzip them.

In [50]:
# Download some custom images from Google Storage
# Note: you can upload your own custom images to Google Colab using the "upload" button in the Files tab
!wget https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip

unzip_data("custom_food_images.zip") 
--2022-07-30 17:04:19--  https://storage.googleapis.com/ztm_tf_course/food_vision/custom_food_images.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.6.128, 142.250.159.128, 142.251.120.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.6.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13192985 (13M) [application/zip]
Saving to: ‘custom_food_images.zip’

custom_food_images. 100%[===================>]  12.58M  52.2MB/s    in 0.2s    

2022-07-30 17:04:19 (52.2 MB/s) - ‘custom_food_images.zip’ saved [13192985/13192985]

Wonderful, we can load these in and turn them into tensors using our load_and_prep_image() function but first we need a list of image filepaths.

In [51]:
# Get custom food images filepaths
custom_food_images = ["custom_food_images/" + img_path for img_path in os.listdir("custom_food_images")]
custom_food_images
Out[51]:
['custom_food_images/hamburger.jpeg',
 'custom_food_images/pizza-dad.jpeg',
 'custom_food_images/chicken_wings.jpeg',
 'custom_food_images/ramen.jpeg',
 'custom_food_images/sushi.jpeg',
 'custom_food_images/steak.jpeg']
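
Note that os.listdir() returns every entry in a folder, including hidden or non-image files which our image loading function wouldn't be able to decode. If your own folder contains other files, a stricter listing can help. Here is a minimal sketch, assuming a hypothetical helper list_image_paths() and an assumed set of image file extensions:

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed set of formats, adjust as needed

def list_image_paths(folder):
    """Return sorted string paths of image files directly inside folder."""
    return sorted(
        str(path) for path in Path(folder).iterdir()
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES
    )

# Demo on a temporary folder holding two images and one non-image file.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["steak.jpeg", "ramen.JPG", "notes.txt"]:
        (Path(tmp) / name).touch()
    found = [Path(p).name for p in list_image_paths(tmp)]

print(found)
# ['ramen.JPG', 'steak.jpeg']
```

In the notebook, list_image_paths("custom_food_images") would stand in for the list comprehension above.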

Now we can use similar code to what we used previously to load in our images, make a prediction on each using our trained model and then plot the image along with the predicted class.

In [52]:
# Make predictions on custom food images
for img in custom_food_images:
  img = load_and_prep_image(img, scale=False) # load in target image and turn it into tensor
  pred_prob = model.predict(tf.expand_dims(img, axis=0)) # add a batch dimension so the input matches the model's expected shape [None, 224, 224, 3]
  pred_class = class_names[pred_prob.argmax()] # find the predicted class label
  # Plot the image with appropriate annotations
  plt.figure()
  plt.imshow(img/255.) # imshow() requires float inputs to be normalized
  plt.title(f"pred: {pred_class}, prob: {pred_prob.max():.2f}")
  plt.axis(False)
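The loop above only reports the single most likely class. When a prediction looks off, it can help to see the runner-up classes too. Here's a minimal sketch of a top-k lookup (the top_k_predictions helper and the example values are made up for illustration, they are not part of the notebook or the model's real outputs).

```python
def top_k_predictions(pred_probs, class_names, k=3):
    """Return the k (class_name, probability) pairs with the highest probability."""
    # Pair each probability with its index, then sort from highest to lowest probability
    ranked = sorted(enumerate(pred_probs), key=lambda pair: pair[1], reverse=True)
    return [(class_names[i], float(prob)) for i, prob in ranked[:k]]


# Made-up probabilities for three classes
example_probs = [0.05, 0.70, 0.25]
example_classes = ["pizza", "ramen", "sushi"]
print(top_k_predictions(example_probs, example_classes, k=2))
# → [('ramen', 0.7), ('sushi', 0.25)]
```

Since model.predict() returns one row of probabilities per image, inside the loop you'd pass the first row, for example top_k_predictions(pred_prob[0], class_names).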

Two thumbs up! How cool is that?! Our Food Vision model has come to life!

Seeing a machine learning model work on a premade test dataset is cool, but seeing it work on your own data is mind-blowing.

And guess what... our model got these incredible results (10%+ better than the baseline) with only 10% of the training images.

I wonder what would happen if we trained a model with all of the data (100% of the training data from Food101 instead of 10%)? Hint: that's your task in the next notebook.

🛠 Exercises¶

  1. Take 3 of your own photos of food and use the trained model to make predictions on them. Share your predictions with the other students in Discord and show off your Food Vision model 🍔👁.
  2. Train a feature-extraction transfer learning model for 10 epochs on the same data and compare its performance versus a model which used feature extraction for 5 epochs and fine-tuning for 5 epochs (like we've used in this notebook). Which method is better?
  3. Recreate our first model (the feature extraction model) with mixed_precision turned on.
    • Does it make the model train faster?
    • Does it affect the accuracy or performance of our model?
    • What are the advantages of using mixed_precision training?

📖 Extra-curriculum¶

  • Spend 15 minutes reading up on the EarlyStopping callback. What does it do? How could we use it in our model training?
  • Spend an hour reading about Streamlit. What does it do? How might you integrate some of the things we've done in this notebook in a Streamlit app?
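
As a head start on the EarlyStopping reading, the core idea is "stop training once the monitored metric hasn't improved for a set number of epochs". The pure Python sketch below illustrates that patience logic only. It is not Keras's implementation: the function name and the loss values are made up for illustration, and the real tf.keras.callbacks.EarlyStopping callback has more options (such as monitor and restore_best_weights).

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it never stops."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            # Validation loss improved: remember it and reset the counter
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation losses: no improvement on 0.7 during epochs 3, 4 and 5
print(early_stopping_epoch([1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.6], patience=3))
# → 5
```

With patience=3 training stops at epoch 5 and never reaches the better loss at epoch 6, which is the trade-off to keep in mind when choosing a patience value.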